xinyu1205 commited on
Commit
616e7e7
·
1 Parent(s): 83e5677

Upload with huggingface_hub

Browse files
.ipynb_checkpoints/gradio_demo-checkpoint.ipynb ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "35d8939e-909d-45d8-bcf9-0ff1dccacfdf",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.3.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.value.bias', 'bert.encoder.layer.2.attention.self.value.bias', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.3.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.11.output.dense.bias', 'bert.encoder.layer.3.attention.self.value.bias', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.2.attention.output.dense.weight', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.4.output.LayerNorm.weight', 'bert.encoder.layer.9.output.dense.weight', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.3.intermediate.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.weight', 'cls.seq_relationship.bias', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.2.output.dense.bias', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.2.attention.self.query.weight', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.2.attention.output.dense.bias', 'bert.encoder.layer.2.output.dense.weight', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.2.intermediate.dense.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.4.attention.self.key.bias', 'bert.pooler.dense.bias', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.5.output.LayerNorm.weight', 'cls.seq_relationship.weight', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.2.attention.self.key.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.pooler.dense.weight', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.8.attention.self.value.weight', 'bert.encoder.layer.4.output.dense.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.2.attention.self.key.weight', 'bert.encoder.layer.4.attention.output.dense.bias', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.4.attention.output.LayerNorm.bias', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.4.intermediate.dense.bias', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.4.attention.self.key.weight', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.7.output.dense.bias', 'bert.embeddings.token_type_embeddings.weight', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.3.output.dense.bias', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.4.attention.output.dense.weight', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.9.attention.output.dense.weight', 'bert.encoder.layer.5.output.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.2.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.self.value.weight', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.3.output.dense.weight', 'bert.encoder.layer.2.output.LayerNorm.weight', 'bert.encoder.layer.4.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.2.attention.output.LayerNorm.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.2.output.LayerNorm.bias', 'bert.encoder.layer.3.output.LayerNorm.bias', 'cls.predictions.decoder.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.2.intermediate.dense.weight', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.9.intermediate.dense.bias', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.2.attention.self.value.weight', 'bert.encoder.layer.3.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.weight', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.9.attention.self.key.weight', 'bert.encoder.layer.4.output.dense.weight', 'bert.encoder.layer.2.attention.self.query.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.dense.bias', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.11.attention.self.value.bias', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.3.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.4.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.query.bias', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.5.output.LayerNorm.bias']\n",
14
+ "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
15
+ "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
16
+ "Some weights of BertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.key.bias']\n",
17
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
18
+ ]
19
+ },
20
+ {
21
+ "name": "stdout",
22
+ "output_type": "stream",
23
+ "text": [
24
+ "/encoder/layer/0/crossattention/self/query is tied\n",
25
+ "/encoder/layer/0/crossattention/self/key is tied\n",
26
+ "/encoder/layer/0/crossattention/self/value is tied\n",
27
+ "/encoder/layer/0/crossattention/output/dense is tied\n",
28
+ "/encoder/layer/0/crossattention/output/LayerNorm is tied\n",
29
+ "/encoder/layer/0/intermediate/dense is tied\n",
30
+ "/encoder/layer/0/output/dense is tied\n",
31
+ "/encoder/layer/0/output/LayerNorm is tied\n",
32
+ "/encoder/layer/1/crossattention/self/query is tied\n",
33
+ "/encoder/layer/1/crossattention/self/key is tied\n",
34
+ "/encoder/layer/1/crossattention/self/value is tied\n",
35
+ "/encoder/layer/1/crossattention/output/dense is tied\n",
36
+ "/encoder/layer/1/crossattention/output/LayerNorm is tied\n",
37
+ "/encoder/layer/1/intermediate/dense is tied\n",
38
+ "/encoder/layer/1/output/dense is tied\n",
39
+ "/encoder/layer/1/output/LayerNorm is tied\n",
40
+ "--------------\n",
41
+ "/home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_14m/blip_tagtotext_encoderdiv_tar_random_swin/caption_coco_finetune_tagparse_tagfinetune_threshold075_bceloss_tagsingle_5e6_epoch19_negative_1_05_pos_1_10/checkpoint_05.pth\n",
42
+ "--------------\n",
43
+ "load checkpoint from /home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_14m/blip_tagtotext_encoderdiv_tar_random_swin/caption_coco_finetune_tagparse_tagfinetune_threshold075_bceloss_tagsingle_5e6_epoch19_negative_1_05_pos_1_10/checkpoint_05.pth\n",
44
+ "vit: swin_b\n",
45
+ "msg_v2 _IncompatibleKeys(missing_keys=['visual_encoder.layers.0.blocks.0.attn.relative_position_index', 'visual_encoder.layers.0.blocks.1.attn_mask', 'visual_encoder.layers.0.blocks.1.attn.relative_position_index', 'visual_encoder.layers.1.blocks.0.attn.relative_position_index', 'visual_encoder.layers.1.blocks.1.attn_mask', 'visual_encoder.layers.1.blocks.1.attn.relative_position_index', 'visual_encoder.layers.2.blocks.0.attn.relative_position_index', 'visual_encoder.layers.2.blocks.1.attn_mask', 'visual_encoder.layers.2.blocks.1.attn.relative_position_index', 'visual_encoder.layers.2.blocks.2.attn.relative_position_index', 'visual_encoder.layers.2.blocks.3.attn_mask', 'visual_encoder.layers.2.blocks.3.attn.relative_position_index', 'visual_encoder.layers.2.blocks.4.attn.relative_position_index', 'visual_encoder.layers.2.blocks.5.attn_mask', 'visual_encoder.layers.2.blocks.5.attn.relative_position_index', 'visual_encoder.layers.2.blocks.6.attn.relative_position_index', 'visual_encoder.layers.2.blocks.7.attn_mask', 'visual_encoder.layers.2.blocks.7.attn.relative_position_index', 'visual_encoder.layers.2.blocks.8.attn.relative_position_index', 'visual_encoder.layers.2.blocks.9.attn_mask', 'visual_encoder.layers.2.blocks.9.attn.relative_position_index', 'visual_encoder.layers.2.blocks.10.attn.relative_position_index', 'visual_encoder.layers.2.blocks.11.attn_mask', 'visual_encoder.layers.2.blocks.11.attn.relative_position_index', 'visual_encoder.layers.2.blocks.12.attn.relative_position_index', 'visual_encoder.layers.2.blocks.13.attn_mask', 'visual_encoder.layers.2.blocks.13.attn.relative_position_index', 'visual_encoder.layers.2.blocks.14.attn.relative_position_index', 'visual_encoder.layers.2.blocks.15.attn_mask', 'visual_encoder.layers.2.blocks.15.attn.relative_position_index', 'visual_encoder.layers.2.blocks.16.attn.relative_position_index', 'visual_encoder.layers.2.blocks.17.attn_mask', 'visual_encoder.layers.2.blocks.17.attn.relative_position_index', 'visual_encoder.layers.3.blocks.0.attn.relative_position_index', 'visual_encoder.layers.3.blocks.1.attn.relative_position_index'], unexpected_keys=[])\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "from PIL import Image\n",
51
+ "import requests\n",
52
+ "import torch\n",
53
+ "from torchvision import transforms\n",
54
+ "from torchvision.transforms.functional import InterpolationMode\n",
55
+ "import ruamel_yaml as yaml\n",
56
+ "from models.tag2text import tag2text_caption\n",
57
+ "\n",
58
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
59
+ "\n",
60
+ "\n",
61
+ "\n",
62
+ "import gradio as gr\n",
63
+ "\n",
64
+ "image_size = 384\n",
65
+ "\n",
66
+ "\n",
67
+ "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
68
+ " std=[0.229, 0.224, 0.225])\n",
69
+ "transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])\n",
70
+ "\n",
71
+ "\n",
72
+ "\n",
73
+ "#######Swin Version\n",
74
+ "pretrained = '/home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_14m/blip_tagtotext_encoderdiv_tar_random_swin/caption_coco_finetune_tagparse_tagfinetune_threshold075_bceloss_tagsingle_5e6_epoch19_negative_1_05_pos_1_10/checkpoint_05.pth'\n",
75
+ "\n",
76
+ "config_file = 'configs/tag2text_caption.yaml'\n",
77
+ "config = yaml.load(open(config_file, 'r'), Loader=yaml.Loader)\n",
78
+ "\n",
79
+ "\n",
80
+ "model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit=config['vit'], \n",
81
+ " vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],\n",
82
+ " prompt=config['prompt'],config=config,threshold = 0.75 )\n",
83
+ "\n",
84
+ "model.eval()\n",
85
+ "model = model.to(device)\n",
86
+ "\n",
87
+ "\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 4,
93
+ "id": "9772dc6f-680d-45a7-b39c-23770eb5258e",
94
+ "metadata": {},
95
+ "outputs": [
96
+ {
97
+ "name": "stdout",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "Running on local URL: http://127.0.0.1:7860\n",
101
+ "Running on public URL: https://202e6e6a-b3d9-4c97.gradio.live\n",
102
+ "\n",
103
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
104
+ ]
105
+ },
106
+ {
107
+ "data": {
108
+ "text/html": [
109
+ "<div><iframe src=\"https://202e6e6a-b3d9-4c97.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
110
+ ],
111
+ "text/plain": [
112
+ "<IPython.core.display.HTML object>"
113
+ ]
114
+ },
115
+ "metadata": {},
116
+ "output_type": "display_data"
117
+ },
118
+ {
119
+ "data": {
120
+ "text/plain": []
121
+ },
122
+ "execution_count": 4,
123
+ "metadata": {},
124
+ "output_type": "execute_result"
125
+ },
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "<class 'PIL.Image.Image'>\n",
131
+ "<class 'PIL.Image.Image'>\n"
132
+ ]
133
+ }
134
+ ],
135
+ "source": [
136
+ "\n",
137
+ "def inference(raw_image, model_n, input_tag, strategy):\n",
138
+ " if model_n == 'Image Captioning':\n",
139
+ " raw_image = raw_image.resize((image_size, image_size))\n",
140
+ " print(type(raw_image))\n",
141
+ " image = transform(raw_image).unsqueeze(0).to(device) \n",
142
+ " model.threshold = 0.75\n",
143
+ " if input_tag == '' or input_tag == 'none' or input_tag == 'None':\n",
144
+ " input_tag_list = None\n",
145
+ " else:\n",
146
+ " input_tag_list = []\n",
147
+ " input_tag_list.append(input_tag.replace(',',' | '))\n",
148
+ " # print(input_tag_list)\n",
149
+ " with torch.no_grad():\n",
150
+ " if strategy == \"Beam search\":\n",
151
+ " \n",
152
+ "\n",
153
+ " caption, tag_predict = model.generate(image,tag_input = input_tag_list, return_tag_predict = True)\n",
154
+ " if input_tag_list == None:\n",
155
+ " tag_1 = tag_predict\n",
156
+ " tag_2 = ['none']\n",
157
+ " else:\n",
158
+ " _, tag_1 = model.generate(image,tag_input = None, return_tag_predict = True)\n",
159
+ " tag_2 = tag_predict\n",
160
+ "\n",
161
+ " else:\n",
162
+ "\n",
163
+ " caption,tag_predict = model.generate(image, tag_input = input_tag_list,sample=True, top_p=0.9, max_length=20, min_length=5, return_tag_predict = True)\n",
164
+ " if input_tag_list == None:\n",
165
+ " tag_1 = tag_predict\n",
166
+ " tag_2 = ['none']\n",
167
+ " else:\n",
168
+ " _, tag_1 = model.generate(image,tag_input = None, return_tag_predict = True)\n",
169
+ " tag_2 = tag_predict\n",
170
+ " # return 'Caption: '+caption[0], 'Identified Tags:' + tag_predict[0]\n",
171
+ " # return tag_predict[0],caption[0]\n",
172
+ " return tag_1[0],tag_2[0],caption[0]\n",
173
+ " \n",
174
+ " # return 'caption: '+caption[0], tag_predict[0]\n",
175
+ "\n",
176
+ " else: \n",
177
+ " image_vq = transform_vq(raw_image).unsqueeze(0).to(device) \n",
178
+ " with torch.no_grad():\n",
179
+ " answer = model_vq(image_vq, question, train=False, inference='generate') \n",
180
+ " return 'answer: '+answer[0]\n",
181
+ " \n",
182
+ "inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=['Image Captioning'], type=\"value\", default=\"Image Captioning\", label=\"Task\"),gr.inputs.Textbox(lines=2, label=\"User Identified Tags (Optional, Enter with commas)\"),gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type=\"value\", default=\"Beam search\", label=\"Caption Decoding Strategy\")]\n",
183
+ "\n",
184
+ "# outputs = gr.outputs.Textbox(label=\"Output\")\n",
185
+ "# outputs = [gr.outputs.Textbox(label=\"Image Caption\"),gr.outputs.Textbox(label=\"Identified Tags\")]\n",
186
+ "outputs = [gr.outputs.Textbox(label=\"Model Identified Tags\"),gr.outputs.Textbox(label=\"User Identified Tags\"), gr.outputs.Textbox(label=\"Image Caption\") ]\n",
187
+ "\n",
188
+ "title = \"Tag2Text\"\n",
189
+ "\n",
190
+ "description = \"Gradio demo for Tag2Text: Guiding Language-Image Model via Image Tagging (Fudan University, OPPO Research Institute, International Digital Economy Academy).\"\n",
191
+ "\n",
192
+ "article = \"<p style='text-align: center'><a href='' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='' target='_blank'>Github Repo</a></p>\"\n",
193
+ "\n",
194
+ "demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['images/COCO_val2014_000000551338.jpg',\"Image Captioning\",\"none\",\"Beam search\"], \n",
195
+ " ['images/COCO_val2014_000000551338.jpg',\"Image Captioning\",\"fence, sky\",\"Beam search\"],\n",
196
+ " # ['images/COCO_val2014_000000551338.jpg',\"Image Captioning\",\"grass\",\"Beam search\"],\n",
197
+ " ['images/COCO_val2014_000000483108.jpg',\"Image Captioning\",\"none\",\"Beam search\"],\n",
198
+ " ['images/COCO_val2014_000000483108.jpg',\"Image Captioning\",\"electric cable\",\"Beam search\"],\n",
199
+ " # ['images/COCO_val2014_000000483108.jpg',\"Image Captioning\",\"sky, train\",\"Beam search\"],\n",
200
+ " ['images/COCO_val2014_000000483108.jpg',\"Image Captioning\",\"track, train\",\"Beam search\"] , \n",
201
+ " ['images/COCO_val2014_000000483108.jpg',\"Image Captioning\",\"grass\",\"Beam search\"] \n",
202
+ " ])\n",
203
+ "\n",
204
+ "\n",
205
+ "demo.launch(share=True)"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": null,
211
+ "id": "0da1f11b-e737-47a9-9b07-4e00c0835f63",
212
+ "metadata": {},
213
+ "outputs": [],
214
+ "source": []
215
+ },
216
+ {
217
+ "cell_type": "code",
218
+ "execution_count": null,
219
+ "id": "73a4bb88-4200-4853-b1ba-34f0d4b6dc34",
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": []
223
+ },
224
+ {
225
+ "cell_type": "code",
226
+ "execution_count": null,
227
+ "id": "3340a61f-c6bc-4ead-87ea-b26aa97b7a68",
228
+ "metadata": {},
229
+ "outputs": [],
230
+ "source": []
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "id": "d49e3de4-c3f7-4835-90eb-d0d013fc0ffb",
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": []
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": null,
243
+ "id": "205e0317-1701-4afd-8d67-bedb6959f350",
244
+ "metadata": {},
245
+ "outputs": [],
246
+ "source": []
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "id": "bf5301a5-80c5-4e44-835e-0160a97fef66",
252
+ "metadata": {},
253
+ "outputs": [],
254
+ "source": []
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "id": "f63d7a06-7625-4e1c-855d-177971217a0d",
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": []
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": null,
267
+ "id": "c929e566-1a6e-4280-96eb-c434ef9a35d0",
268
+ "metadata": {},
269
+ "outputs": [],
270
+ "source": []
271
+ }
272
+ ],
273
+ "metadata": {
274
+ "kernelspec": {
275
+ "display_name": "Python 3 (ipykernel)",
276
+ "language": "python",
277
+ "name": "python3"
278
+ },
279
+ "language_info": {
280
+ "codemirror_mode": {
281
+ "name": "ipython",
282
+ "version": 3
283
+ },
284
+ "file_extension": ".py",
285
+ "mimetype": "text/x-python",
286
+ "name": "python",
287
+ "nbconvert_exporter": "python",
288
+ "pygments_lexer": "ipython3",
289
+ "version": "3.7.12"
290
+ }
291
+ },
292
+ "nbformat": 4,
293
+ "nbformat_minor": 5
294
+ }
.ipynb_checkpoints/upload-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
README.md CHANGED
@@ -1,13 +1,7 @@
1
- ---
2
- title: Tag2Text
3
- emoji:
4
- colorFrom: gray
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 3.16.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+
2
+
3
+ Welcome to Tag2Text demo! (Fudan University, OPPO Research Institute, International Digital Economy Academy).
4
+
5
+ Upload your image to get the tags and caption of the image. Optional: You can also input specified tags to get the corresponding caption.
6
+
7
+ We are constantly updating this demo.
 
 
 
 
 
 
app.py CHANGED
@@ -14,7 +14,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
  image_size = 384
16
 
17
-
18
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
19
  std=[0.229, 0.224, 0.225])
20
  transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])
@@ -26,7 +25,6 @@ pretrained = '/home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_1
26
  config_file = 'configs/tag2text_caption.yaml'
27
  config = yaml.load(open(config_file, 'r'), Loader=yaml.Loader)
28
 
29
-
30
  model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit=config['vit'],
31
  vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
32
  prompt=config['prompt'],config=config,threshold = 0.75 )
@@ -35,66 +33,41 @@ model.eval()
35
  model = model.to(device)
36
 
37
 
38
- def inference(raw_image, model_n, input_tag, strategy):
39
- if model_n == 'Image Captioning':
40
- raw_image = raw_image.resize((image_size, image_size))
41
- image = transform(raw_image).unsqueeze(0).to(device)
42
- model.threshold = 0.7
43
- if input_tag == '' or input_tag == 'none' or input_tag == 'None':
44
- input_tag_list = None
 
 
 
 
 
 
 
 
 
 
45
  else:
46
- input_tag_list = []
47
- input_tag_list.append(input_tag.replace(',',' | '))
48
- with torch.no_grad():
49
- if strategy == "Beam search":
50
-
51
-
52
- caption, tag_predict = model.generate(image,tag_input = input_tag_list, return_tag_predict = True)
53
- if input_tag_list == None:
54
- tag_1 = tag_predict
55
- tag_2 = ['none']
56
- else:
57
- _, tag_1 = model.generate(image,tag_input = None, return_tag_predict = True)
58
- tag_2 = tag_predict
59
-
60
- else:
61
-
62
- caption,tag_predict = model.generate(image, tag_input = input_tag_list,sample=True, top_p=0.9, max_length=20, min_length=5, return_tag_predict = True)
63
- if input_tag_list == None:
64
- tag_1 = tag_predict
65
- tag_2 = ['none']
66
- else:
67
- _, tag_1 = model.generate(image,tag_input = None, return_tag_predict = True)
68
- tag_2 = tag_predict
69
- return tag_1[0],tag_2[0],caption[0]
70
-
71
-
72
- else:
73
- image_vq = transform_vq(raw_image).unsqueeze(0).to(device)
74
- with torch.no_grad():
75
- answer = model_vq(image_vq, question, train=False, inference='generate')
76
- return 'answer: '+answer[0]
77
-
78
- inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=['Image Captioning'], type="value", default="Image Captioning", label="Task"),gr.inputs.Textbox(lines=2, label="User Identified Tags (Optional, Enter with commas)"),gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type="value", default="Beam search", label="Caption Decoding Strategy")]
79
-
80
- outputs = [gr.outputs.Textbox(label="Model Identified Tags"),gr.outputs.Textbox(label="User Identified Tags"), gr.outputs.Textbox(label="Image Caption") ]
81
 
82
- title = "Tag2Text"
83
 
84
- description = "Gradio demo for Tag2Text: Guiding Language-Image Model via Image Tagging (Fudan University, OPPO Research Institute, International Digital Economy Academy)."
85
 
86
- article = "<p style='text-align: center'><a href='' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='' target='_blank'>Github Repo</a></p>"
87
 
88
- demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['images/COCO_val2014_000000551338.jpg',"Image Captioning","none","Beam search"],
89
- ['images/COCO_val2014_000000551338.jpg',"Image Captioning","fence, sky","Beam search"],
90
- # ['images/COCO_val2014_000000551338.jpg',"Image Captioning","grass","Beam search"],
91
- ['images/COCO_val2014_000000483108.jpg',"Image Captioning","none","Beam search"],
92
- ['images/COCO_val2014_000000483108.jpg',"Image Captioning","electric cable","Beam search"],
93
- # ['images/COCO_val2014_000000483108.jpg',"Image Captioning","sky, train","Beam search"],
94
- ['images/COCO_val2014_000000483108.jpg',"Image Captioning","track, train","Beam search"] ,
95
- ['images/COCO_val2014_000000483108.jpg',"Image Captioning","grass","Beam search"]
96
- ])
97
 
 
 
98
 
 
99
 
 
 
 
 
100
 
 
14
 
15
  image_size = 384
16
 
 
17
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
18
  std=[0.229, 0.224, 0.225])
19
  transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])
 
25
  config_file = 'configs/tag2text_caption.yaml'
26
  config = yaml.load(open(config_file, 'r'), Loader=yaml.Loader)
27
 
 
28
  model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit=config['vit'],
29
  vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
30
  prompt=config['prompt'],config=config,threshold = 0.75 )
 
33
  model = model.to(device)
34
 
35
 
36
+ def inference(raw_image, input_tag):
37
+ raw_image = raw_image.resize((image_size, image_size))
38
+
39
+ image = transform(raw_image).unsqueeze(0).to(device)
40
+ model.threshold = 0.69
41
+ if input_tag == '' or input_tag == 'none' or input_tag == 'None':
42
+ input_tag_list = None
43
+ else:
44
+ input_tag_list = []
45
+ input_tag_list.append(input_tag.replace(',',' | '))
46
+ with torch.no_grad():
47
+
48
+
49
+ caption, tag_predict = model.generate(image,tag_input = input_tag_list, return_tag_predict = True)
50
+ if input_tag_list == None:
51
+ tag_1 = tag_predict
52
+ tag_2 = ['none']
53
  else:
54
+ _, tag_1 = model.generate(image,tag_input = None, return_tag_predict = True)
55
+ tag_2 = tag_predict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ return tag_1[0],tag_2[0],caption[0]
58
 
 
59
 
60
+ inputs = [gr.inputs.Image(type='pil'),gr.inputs.Textbox(lines=2, label="User Specified Tags (Optional, Enter with commas)")]
61
 
62
+ outputs = [gr.outputs.Textbox(label="Model Identified Tags"),gr.outputs.Textbox(label="User Specified Tags"), gr.outputs.Textbox(label="Image Caption") ]
 
 
 
 
 
 
 
 
63
 
64
+ title = "Tag2Text"
65
+ description = "Welcome to Tag2Text demo! (Supported by Fudan University, OPPO Research Institute, International Digital Economy Academy) <br/> Upload your image to get the tags and caption of the image. Optional: You can also input specified tags to get the corresponding caption."
66
 
67
+ article = "<p style='text-align: center'><a href='' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='' target='_blank'>Github Repo</a></p>"
68
 
69
+ demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['images/COCO_val2014_000000483108.jpg',"none"],
70
+ ['images/COCO_val2014_000000483108.jpg',"electric cable"],
71
+ ['images/COCO_val2014_000000483108.jpg',"track, train"] ,
72
+ ])
73
 
configs/med_config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30524,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true
21
+ }
configs/q2l_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertModel"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 4,
15
+ "num_hidden_layers": 2,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30522,
19
+ "encoder_width": 768,
20
+ "add_cross_attention": true,
21
+ "add_tag_cross_attention": false
22
+ }
23
+
configs/swin/config_swinB_224.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
+ "vision_width": 1024,
4
+ "image_res": 224,
5
+ "window_size": 7,
6
+ "embed_dim": 128,
7
+ "depths": [ 2, 2, 18, 2 ],
8
+ "num_heads": [ 4, 8, 16, 32 ]
9
+ }
10
+
configs/swin/config_swinB_384.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
+ "vision_width": 1024,
4
+ "image_res": 384,
5
+ "window_size": 12,
6
+ "embed_dim": 128,
7
+ "depths": [ 2, 2, 18, 2 ],
8
+ "num_heads": [ 4, 8, 16, 32 ]
9
+ }
10
+
configs/swin/config_swinB_480.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
+ "vision_width": 1024,
4
+ "image_res": 480,
5
+ "window_size": 15,
6
+ "embed_dim": 128,
7
+ "depths": [ 2, 2, 18, 2 ],
8
+ "num_heads": [ 4, 8, 16, 32 ]
9
+ }
configs/swin/config_swinB_576.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
+ "vision_width": 1024,
4
+ "image_res": 576,
5
+ "window_size": 18,
6
+ "embed_dim": 128,
7
+ "depths": [ 2, 2, 18, 2 ],
8
+ "num_heads": [ 4, 8, 16, 32 ]
9
+ }
configs/swin/config_swinB_608.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth",
3
+ "vision_width": 1024,
4
+ "image_res": 608,
5
+ "window_size": 19,
6
+ "embed_dim": 128,
7
+ "depths": [ 2, 2, 18, 2 ],
8
+ "num_heads": [ 4, 8, 16, 32 ]
9
+ }
configs/tag2text_caption.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ image_root: '/home/notebook/data/group/projects/tagging/caption/datasets/public/coco/'
2
+
3
+ ann_root: 'dataset/caption_dataset'
4
+ coco_gt_root: 'dataset/caption_dataset'
5
+
6
+ pretrained: '/home/notebook/code/personal/S9049611/BLIP/output/pretrain_caption_tagtotext_v2_bert_asl'
7
+
8
+ # size of vit model; base or large
9
+ vit: 'swin_b'
10
+ vit_grad_ckpt: False
11
+ vit_ckpt_layer: 0
12
+
13
+ batch_size: 35
14
+ init_lr: 5e-6
15
+
16
+ image_size: 384
17
+
18
+ # generation configs
19
+ max_length: 20
20
+ min_length: 5
21
+ num_beams: 3
22
+ prompt: 'a picture of '
23
+
24
+ # optimizer
25
+ weight_decay: 0.05
26
+ min_lr: 0
27
+ max_epoch: 10
28
+
29
+ text_pretrain: 'bert'
30
+
31
+ class_num: 3429
32
+ threshold: 0.7
33
+
data/__pycache__/tag_class.cpython-37.pyc ADDED
Binary file (52 kB). View file
 
data/tag_class.py ADDED
@@ -0,0 +1,3437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ tra_array = ['tennis',
5
+ 'bear cub',
6
+ 'observatory',
7
+ 'bicycle',
8
+ 'hillside',
9
+ 'judge',
10
+ 'watercolor illustration',
11
+ 'granite',
12
+ 'lobster',
13
+ 'livery',
14
+ 'stone',
15
+ 'ceramic',
16
+ 'ranch',
17
+ 'cloth',
18
+ 'smile',
19
+ 'building',
20
+ 'tattoo',
21
+ 'cricketer',
22
+ 'cheek',
23
+ 'pear',
24
+ 'source',
25
+ 'winter',
26
+ 'surface',
27
+ 'spray',
28
+ 'ceremony',
29
+ 'magic',
30
+ 'curve',
31
+ 'container',
32
+ 'fair',
33
+ 'medicine',
34
+ 'baby',
35
+ 'tennis racquet',
36
+ 'ornament',
37
+ 'bamboo',
38
+ 'duckling',
39
+ 'song',
40
+ 'safari',
41
+ 'team presentation',
42
+ 'daffodil',
43
+ 'cross',
44
+ 'toothpaste',
45
+ 'shield',
46
+ 'fashion model',
47
+ 'capsule',
48
+ 'map',
49
+ 'creek',
50
+ 'glass house',
51
+ 'glass plate',
52
+ 'siding',
53
+ 'corner',
54
+ 'water buffalo',
55
+ 'bison',
56
+ 'figure skater',
57
+ 'diploma',
58
+ 'tire',
59
+ 'race',
60
+ 'cable car',
61
+ 'brain',
62
+ 'gas stove',
63
+ 'soap bubble',
64
+ 'palette',
65
+ 'snowboard',
66
+ 'school child',
67
+ 'trench coat',
68
+ 'monk',
69
+ 'fiber',
70
+ 'kitchen window',
71
+ 'sunglass',
72
+ 'coffee',
73
+ 'security',
74
+ 'strawberry',
75
+ 'penguin',
76
+ 'tree root',
77
+ 'loaf',
78
+ 'engagement ring',
79
+ 'lamb',
80
+ 'vector cartoon illustration',
81
+ 'sandwich',
82
+ 'mountain village',
83
+ 'shape',
84
+ 'charm',
85
+ 'fiction',
86
+ 'knot',
87
+ 'greenhouse',
88
+ 'sushi',
89
+ 'text',
90
+ 'disaster',
91
+ 'trophy',
92
+ 'gang',
93
+ 'strap',
94
+ 'soccer game',
95
+ 'cardinal',
96
+ 'tee',
97
+ 'turtle',
98
+ 'water surface',
99
+ 'grassland',
100
+ 'dolphin',
101
+ 'store',
102
+ 'dirt',
103
+ 'iceberg',
104
+ 'pergola',
105
+ 'farmer market',
106
+ 'publicity portrait',
107
+ 'tote bag',
108
+ 'teenage girl',
109
+ 'view mirror',
110
+ 'session',
111
+ 'commuter',
112
+ 'dressing room',
113
+ 'tricycle',
114
+ 'christmas ball',
115
+ 'headlight',
116
+ 'police',
117
+ 'armchair',
118
+ 'chart',
119
+ 'yacht',
120
+ 'saw',
121
+ 'printer',
122
+ 'rock band',
123
+ 'gingerbread house',
124
+ 'tag',
125
+ 'table lamp',
126
+ 'hockey game',
127
+ 'slope',
128
+ 'font',
129
+ 'wicker basket',
130
+ 'jewelry',
131
+ 'quarter',
132
+ 'software',
133
+ 'weapon',
134
+ 'pin',
135
+ 'worship',
136
+ 'painter',
137
+ 'goal',
138
+ 'morning light',
139
+ 'bike',
140
+ 'baseball bat',
141
+ 'elevator',
142
+ 'cuisine',
143
+ 'sausage',
144
+ 'stunt',
145
+ 'wrestler',
146
+ 'statue',
147
+ 'landing',
148
+ 'pillar',
149
+ 'willow tree',
150
+ 'sea wave',
151
+ 'chicken',
152
+ 'peanut',
153
+ 'muscle',
154
+ 'bob',
155
+ 'tv genre',
156
+ 'bathroom window',
157
+ 'radish',
158
+ 'textile',
159
+ 'pelican',
160
+ 'marketplace',
161
+ 'crest',
162
+ 'elevation map',
163
+ 'gift',
164
+ 'parish',
165
+ 'traffic light',
166
+ 'campfire',
167
+ 'fog',
168
+ 'award winner',
169
+ 'beach ball',
170
+ 'mat',
171
+ 'white house',
172
+ 'plaster',
173
+ 'moped',
174
+ 'football team',
175
+ 'solution',
176
+ 'bicyclist',
177
+ 'bit',
178
+ 'playground',
179
+ 'darkness',
180
+ 'cake',
181
+ 'maple leave',
182
+ 'mold',
183
+ 'cracker',
184
+ 'blueberry',
185
+ 'rubble',
186
+ 'container ship',
187
+ 'pedestrian bridge',
188
+ 'snail',
189
+ 'parrot',
190
+ 'form',
191
+ 'circuit',
192
+ 'highlight',
193
+ 'pickup truck',
194
+ 'koala',
195
+ 'rain',
196
+ 'system',
197
+ 'weather',
198
+ 'raincoat',
199
+ 'soccer team',
200
+ 'windshield',
201
+ 'thunderstorm',
202
+ 'mike',
203
+ 'bird house',
204
+ 'bridge',
205
+ 'grandfather',
206
+ 'restroom',
207
+ 'animation',
208
+ 'wilderness',
209
+ 'clown',
210
+ 'banana',
211
+ 'brown',
212
+ 'braid',
213
+ 'dining room',
214
+ 'kindergarten',
215
+ 'launch event',
216
+ 'purple',
217
+ 'school',
218
+ 'stairwell',
219
+ 'brooch',
220
+ 'movie poster image',
221
+ 'mountain river',
222
+ 'shelf',
223
+ 'wicket',
224
+ 'headboard',
225
+ 'buddha',
226
+ 'flower field',
227
+ 'dugout',
228
+ 'cd',
229
+ 'bald eagle',
230
+ 'lagoon',
231
+ 'seaweed',
232
+ 'agriculture',
233
+ 'emergency service',
234
+ 'maple tree',
235
+ 'parachute',
236
+ 'continent',
237
+ 'amusement park',
238
+ 'remote',
239
+ 'bun',
240
+ 'tackle',
241
+ 'hospital',
242
+ 'garage door',
243
+ 'birthday party',
244
+ 'friendship',
245
+ 'go',
246
+ 'mausoleum',
247
+ 'jeep',
248
+ 'raccoon',
249
+ 'step',
250
+ 'ice hockey team',
251
+ 'cigarette',
252
+ 'lace dress',
253
+ 'forest floor',
254
+ 'mall',
255
+ 'captain',
256
+ 'milk',
257
+ 'golf course',
258
+ 'meal',
259
+ 'picnic table',
260
+ 'sail',
261
+ 'volleyball',
262
+ 'canal',
263
+ 'terrace',
264
+ 'computer desk',
265
+ 'caravan',
266
+ 'hotel',
267
+ 'cheerleader',
268
+ 'nurse',
269
+ 'museum',
270
+ 'marsh',
271
+ 'fox',
272
+ 'plateau',
273
+ 'night',
274
+ 'twin',
275
+ 'letter logo',
276
+ 'autumn tree',
277
+ 'powder',
278
+ 'convention',
279
+ 'creature',
280
+ 'lighthouse',
281
+ 'shop window',
282
+ 'jacket',
283
+ 'stork',
284
+ 'taxi',
285
+ 'trade',
286
+ 'blackboard',
287
+ 'olive',
288
+ 'road sign',
289
+ 'resort',
290
+ 'snowflake',
291
+ 'cemetery',
292
+ 'travel',
293
+ 'evening dress',
294
+ 'picnic',
295
+ 'drink',
296
+ 'winter morning',
297
+ 'football player',
298
+ 'snack',
299
+ 'boxing glove',
300
+ 'dinner party',
301
+ 'airline',
302
+ 'swing',
303
+ 'port',
304
+ 'wheelbarrow',
305
+ 'bathroom sink',
306
+ 'sweater',
307
+ 'ambulance',
308
+ 'gear',
309
+ 'oil',
310
+ 'wii controller',
311
+ 'array',
312
+ 'home office',
313
+ 'car show',
314
+ 'mixture',
315
+ 'profession',
316
+ 'tree frog',
317
+ 'square',
318
+ 'facility',
319
+ 'coral reef',
320
+ 'sea wall',
321
+ 'pizza',
322
+ 'exhibit',
323
+ 'demolition',
324
+ 'trout',
325
+ 'ring',
326
+ 'coffee shop',
327
+ 'bracelet',
328
+ 'bean',
329
+ 'lip',
330
+ 'fencing',
331
+ 'landscape',
332
+ 'sitting',
333
+ 'package',
334
+ 'metal',
335
+ 'bust',
336
+ 'king',
337
+ 'hair',
338
+ 'window seat',
339
+ 'wildlife',
340
+ 'trunk',
341
+ 'greenery',
342
+ 'stencil',
343
+ 'fire hydrant',
344
+ 'bridesmaid',
345
+ 'plaza',
346
+ 'alps',
347
+ 'tower bridge',
348
+ 'crop top',
349
+ 'crossing',
350
+ 'cinema',
351
+ 'pedestrian crossing',
352
+ 'family',
353
+ 'shopping cart',
354
+ 'stomach',
355
+ 'church building',
356
+ 'screen door',
357
+ 'skater',
358
+ 'soccer field',
359
+ 'kettle',
360
+ 'mussel',
361
+ 'raindrop',
362
+ 'candy cane',
363
+ 'water lily',
364
+ 'flower girl',
365
+ 'desert',
366
+ 'enclosure',
367
+ 'christmas light',
368
+ 'kitchen',
369
+ 'caterpillar',
370
+ 'plaid',
371
+ 'bath',
372
+ 'bush',
373
+ 'mud',
374
+ 'ballet',
375
+ 'knee',
376
+ 'adult',
377
+ 'raft',
378
+ 'sea view',
379
+ 'cactus',
380
+ 'office chair',
381
+ 'overall',
382
+ 'rim',
383
+ 'scaffolding',
384
+ 'pig',
385
+ 'cover',
386
+ 'poster page',
387
+ 'sprinkle',
388
+ 'chandelier',
389
+ 'algae',
390
+ 'traffic',
391
+ 'surfboard',
392
+ 'book',
393
+ 'filming',
394
+ 'flash',
395
+ 'mansion',
396
+ 'camouflage',
397
+ 'trouser',
398
+ 'ticket',
399
+ 'weed',
400
+ 'cab',
401
+ 'trench',
402
+ 'elephant',
403
+ 'huddle',
404
+ 'sphere',
405
+ 'christmas decoration',
406
+ 'city',
407
+ 'launch',
408
+ 'doll',
409
+ 'christmas ornament',
410
+ 'fabric',
411
+ 'bikini',
412
+ 'biplane',
413
+ 'breakfast',
414
+ 'neighbourhood',
415
+ 'race track',
416
+ 'foliage',
417
+ 'avocado',
418
+ 'school bus',
419
+ 'footwear',
420
+ 'highway',
421
+ 'ocean view',
422
+ 'art vector illustration',
423
+ 'wall clock',
424
+ 'curtain',
425
+ 'teenager',
426
+ 'kitchen area',
427
+ 'robot',
428
+ 'tusk',
429
+ 'lounge chair',
430
+ 'beam',
431
+ 'paddle',
432
+ 'camel',
433
+ 'lid',
434
+ 'world map',
435
+ 'city view',
436
+ 'newlywed',
437
+ 'cargo ship',
438
+ 'yellow',
439
+ 'exhibition',
440
+ 'bend',
441
+ 'novel',
442
+ 'wool',
443
+ 'ontario',
444
+ 'bread',
445
+ 'campus',
446
+ 'coastline',
447
+ 'cutting board',
448
+ 'booth',
449
+ 'table top',
450
+ 'carpet',
451
+ 'beach chair',
452
+ 'workout',
453
+ 'street food',
454
+ 'fun',
455
+ 'costumer film designer',
456
+ 'gadget',
457
+ 'artist',
458
+ 'fishing village',
459
+ 'builder',
460
+ 'violinist',
461
+ 'iphone',
462
+ 'spider web',
463
+ 'traffic sign',
464
+ 'ruin',
465
+ 'rescue',
466
+ 'clipboard',
467
+ 'seal',
468
+ 'film director',
469
+ 'paw',
470
+ 'nursery',
471
+ 'intersection',
472
+ 'tomato sauce',
473
+ 'taste',
474
+ 'paddy field',
475
+ 'christmas tree',
476
+ 'wave',
477
+ 'stool',
478
+ 'watering can',
479
+ 'rug',
480
+ 'daytime',
481
+ 'subway station',
482
+ 'craft',
483
+ 'pine forest',
484
+ 'black',
485
+ 'planet',
486
+ 'motif',
487
+ 'christmas market',
488
+ 'glass window',
489
+ 'college',
490
+ 'wheat',
491
+ 'damage',
492
+ 'rectangle',
493
+ 'picture frame',
494
+ 'chess',
495
+ 'guest room',
496
+ 'street corner',
497
+ 'religion',
498
+ 'seed',
499
+ 'puzzle',
500
+ 'freeway',
501
+ 'beauty',
502
+ 'ocean',
503
+ 'watch',
504
+ 'mother',
505
+ 'garage',
506
+ 'quote',
507
+ 'dj',
508
+ 'supporter',
509
+ 'hip hop artist',
510
+ 'muffin',
511
+ 'eiffel tower',
512
+ 'cash',
513
+ 'firefighter',
514
+ 'cauliflower',
515
+ 'bunker',
516
+ 'sled',
517
+ 'manicure',
518
+ 'shark',
519
+ 'stall',
520
+ 'jungle',
521
+ 'family home',
522
+ 'tour bus',
523
+ 'chimney',
524
+ 'touchdown',
525
+ 'roundabout',
526
+ 'coyote',
527
+ 'street scene',
528
+ 'tank',
529
+ 'wedding dress',
530
+ 'mantle',
531
+ 'bedroom window',
532
+ 'coconut',
533
+ 'chapel',
534
+ 'goat',
535
+ 'living space',
536
+ 'rock wall',
537
+ 'polka dot',
538
+ 'railway',
539
+ 'mandala',
540
+ 'mango',
541
+ 'lesson',
542
+ 'mountain landscape',
543
+ 'team photo',
544
+ 'bookshelf',
545
+ 'meter',
546
+ 'bulldog',
547
+ 'evening sun',
548
+ 'stick',
549
+ 'card',
550
+ 'pink',
551
+ 'fish pond',
552
+ 'paint',
553
+ 'pill',
554
+ 'cart',
555
+ 'pea',
556
+ 'van',
557
+ 'album',
558
+ 'football college game',
559
+ 'mountain pass',
560
+ 'doughnut',
561
+ 'ski slope',
562
+ 'match',
563
+ 'official',
564
+ 'shadow',
565
+ 'organ',
566
+ 'celebration',
567
+ 'coin',
568
+ 'log cabin',
569
+ 'firework display',
570
+ 'present',
571
+ 'twig',
572
+ 'chef',
573
+ 'confetti',
574
+ 'footpath',
575
+ 'tour',
576
+ 'ponytail',
577
+ 'artwork',
578
+ 'race car',
579
+ 'club',
580
+ 'season',
581
+ 'hose',
582
+ 'pencil',
583
+ 'aircraft',
584
+ 'rock formation',
585
+ 'wardrobe',
586
+ 'participant',
587
+ 'politician',
588
+ 'engineer',
589
+ 'peace',
590
+ 'filter',
591
+ 'sailing boat',
592
+ 'water bottle',
593
+ 'service dog',
594
+ 'poodle',
595
+ 'loki',
596
+ 'statesman',
597
+ 'sleeping bag',
598
+ 'outskirt',
599
+ 'clock',
600
+ 'factory',
601
+ 'oak tree',
602
+ 'physician',
603
+ 'color',
604
+ 'room',
605
+ 'stairway',
606
+ 'company',
607
+ 'lady',
608
+ 'graph',
609
+ 'faucet',
610
+ 'tablecloth',
611
+ 'subway train',
612
+ 'chocolate chip cookie',
613
+ 'headquarters',
614
+ 'screw',
615
+ 'goggle',
616
+ 'halloween',
617
+ 'city street',
618
+ 'swirl',
619
+ 'cord',
620
+ 'forward',
621
+ 'bone',
622
+ 'bedding',
623
+ 'archway',
624
+ 'wig',
625
+ 'lobby',
626
+ 'mask',
627
+ 'attic',
628
+ 'kitchen table',
629
+ 'skylight',
630
+ 'fire',
631
+ 'exit',
632
+ 'oil painting',
633
+ 'passenger',
634
+ 'meditation',
635
+ 'salmon',
636
+ 'fedora',
637
+ 'rubber stamp',
638
+ 'orange juice',
639
+ 'arch',
640
+ 'scientist',
641
+ 'stroll',
642
+ 'manhattan',
643
+ 'float',
644
+ 'baseball uniform',
645
+ 'circle',
646
+ 'church',
647
+ 'decker bus',
648
+ 'competitor',
649
+ 'zoo',
650
+ 'basketball team',
651
+ 'tourist',
652
+ 'daughter',
653
+ 'silverware',
654
+ 'ceiling fan',
655
+ 'birth',
656
+ 'vase',
657
+ 'jack',
658
+ 'mushroom',
659
+ 'spiral',
660
+ 'cage',
661
+ 'limb',
662
+ 'salad',
663
+ 'ad',
664
+ 'control',
665
+ 'earth',
666
+ 'party',
667
+ 'bolt',
668
+ 'tractor',
669
+ 'barley',
670
+ 'wedding photo',
671
+ 'hawk',
672
+ 'warehouse',
673
+ 'vegetable garden',
674
+ 'chocolate cake',
675
+ 'cabbage',
676
+ 'floor window',
677
+ 'baby shower',
678
+ 'magnifying glass',
679
+ 'table',
680
+ 'stethoscope',
681
+ 'reading',
682
+ 'mission',
683
+ 'croissant',
684
+ 'gift box',
685
+ 'rocket',
686
+ 'forest road',
687
+ 'cooking',
688
+ 'suite',
689
+ 'hill country',
690
+ 'motorcycle',
691
+ 'baseball player',
692
+ 'angle',
693
+ 'drug',
694
+ 'sport association',
695
+ 'championship',
696
+ 'family portrait',
697
+ 'florist',
698
+ 'softball',
699
+ 'egret',
700
+ 'office',
701
+ 'plywood',
702
+ 'jockey',
703
+ 'mosque',
704
+ 'brunch',
705
+ 'beanie',
706
+ 'office building',
707
+ 'pattern',
708
+ 'calendar',
709
+ 'indoor',
710
+ 'pepper',
711
+ 'ledge',
712
+ 'trail',
713
+ 'fuel',
714
+ 'laptop computer',
715
+ 'tennis shoe',
716
+ 'deck chair',
717
+ 'guitarist',
718
+ 'barn',
719
+ 'surgery',
720
+ 'cartoon illustration',
721
+ 'nebula',
722
+ 'railroad',
723
+ 'mountain goat',
724
+ 'goose',
725
+ 'car door',
726
+ 'cheer',
727
+ 'liquid',
728
+ 'hardwood floor',
729
+ 'pathway',
730
+ 'acorn',
731
+ 'gull',
732
+ 'airliner',
733
+ 'couch',
734
+ 'lake house',
735
+ 'spaghetti',
736
+ 'promenade',
737
+ 'collection',
738
+ 'garden',
739
+ 'bank',
740
+ 'robin',
741
+ 'tennis ball',
742
+ 'peony',
743
+ 'gymnast',
744
+ 'lavender',
745
+ 'deck',
746
+ 'test',
747
+ 'riverside',
748
+ 'rapper',
749
+ 'domino',
750
+ 'bride',
751
+ 'mouse',
752
+ 'basil',
753
+ 'wedding couple',
754
+ 'ocean wave',
755
+ 'arm',
756
+ 'kitchen floor',
757
+ 'grove',
758
+ 'family member',
759
+ 'backyard',
760
+ 'raspberry',
761
+ 'forest fire',
762
+ 'officer',
763
+ 'hibiscus',
764
+ 'canyon',
765
+ 'composer',
766
+ 'signature',
767
+ 'olive oil',
768
+ 'hibiscus flower',
769
+ 'rose',
770
+ 'vector icon',
771
+ 'sunrise',
772
+ 'horseback',
773
+ 'motor scooter',
774
+ 'office worker',
775
+ 'tradition',
776
+ 'ingredient',
777
+ 'washing machine',
778
+ 'lighting',
779
+ 'bagel',
780
+ 'sailboat',
781
+ 'policeman',
782
+ 'mare',
783
+ 'graphic',
784
+ 'halloween pumpkin',
785
+ 'stock',
786
+ 'pilot',
787
+ 'education',
788
+ 'team',
789
+ 'body',
790
+ 'horse',
791
+ 'kimono',
792
+ 'bazaar',
793
+ 'bag',
794
+ 'recording studio',
795
+ 'parsley',
796
+ 'entrance',
797
+ 'denim',
798
+ 'vet',
799
+ 'horse farm',
800
+ 'charcoal',
801
+ 'architecture',
802
+ 'glass vase',
803
+ 'puppy',
804
+ 'estuary',
805
+ 'television show host',
806
+ 'city bus',
807
+ 'shoulder',
808
+ 'beast',
809
+ 'balance',
810
+ 'golfer',
811
+ 'roadside',
812
+ 'denim jacket',
813
+ 'stone wall',
814
+ 'counter top',
815
+ 'app icon',
816
+ 'toast',
817
+ 'head coach',
818
+ 'ham',
819
+ 'warrior',
820
+ 'gem',
821
+ 'refrigerator',
822
+ 'snowman',
823
+ 'construction worker',
824
+ 'coal',
825
+ 'website',
826
+ 'morning fog',
827
+ 'mustard',
828
+ 'human',
829
+ 'owl',
830
+ 'puppy dog',
831
+ 'piggy bank',
832
+ 'vegetation',
833
+ 'pirate',
834
+ 'action film',
835
+ 'marshmallow',
836
+ 'thanksgiving',
837
+ 'business',
838
+ 'disease',
839
+ 'signage',
840
+ 'greeting',
841
+ 'skate park',
842
+ 'tile',
843
+ 'mouth',
844
+ 'spinach',
845
+ 'vacation',
846
+ 'leader',
847
+ 'shrine',
848
+ 'walker',
849
+ 'science fiction film',
850
+ 'bill',
851
+ 'rabbit',
852
+ 'motor boat',
853
+ 'bar',
854
+ 'radio',
855
+ 'barge',
856
+ 'tail',
857
+ 'chainsaw',
858
+ 'gallery',
859
+ 'rainbow',
860
+ 'pasta',
861
+ 'padlock',
862
+ 'web',
863
+ 'pastry',
864
+ 'ink',
865
+ 'reef',
866
+ 'school uniform',
867
+ 'shawl',
868
+ 'treasure',
869
+ 'peach',
870
+ 'dinner table',
871
+ 'injury',
872
+ 'harbor',
873
+ 'witch',
874
+ 'car dealership',
875
+ 'litter',
876
+ 'gesture',
877
+ 'documentary',
878
+ 'marriage',
879
+ 'sea shell',
880
+ 'priest',
881
+ 'dome',
882
+ 'kit',
883
+ 'icon',
884
+ 'seaside',
885
+ 'bucket',
886
+ 'entertainment',
887
+ 'stable',
888
+ 'hat',
889
+ 'puddle',
890
+ 'sock',
891
+ 'shopper',
892
+ 'technology',
893
+ 'harbour',
894
+ 'orbit',
895
+ 'antler',
896
+ 'tube',
897
+ 'flag waving',
898
+ 'cook',
899
+ 'tight',
900
+ 'commander',
901
+ 'farmland',
902
+ 'switch',
903
+ 'hiker',
904
+ 'wedding ceremony',
905
+ 'award ceremony',
906
+ 'champion',
907
+ 'chopstick',
908
+ 'farmhouse',
909
+ 'performer',
910
+ 'spike',
911
+ 'accident',
912
+ 'cruise ship',
913
+ 'passenger train',
914
+ 'attraction',
915
+ 'entertainer',
916
+ 'rear view',
917
+ 'sidewalk',
918
+ 'parade',
919
+ 'racing',
920
+ 'plane',
921
+ 'ritual',
922
+ 'peacock',
923
+ 'pocket',
924
+ 'plum',
925
+ 'drop',
926
+ 'carrot',
927
+ 'floor',
928
+ 'sunset',
929
+ 'troop',
930
+ 'architect',
931
+ 'coffee table',
932
+ 'dust',
933
+ 'outline',
934
+ 'leather',
935
+ 'charity event',
936
+ 'heat',
937
+ 'whale',
938
+ 'laundry',
939
+ 'coconut tree',
940
+ 'crosswalk',
941
+ 'pony',
942
+ 'ant',
943
+ 'pipe',
944
+ 'string',
945
+ 'coat',
946
+ 'angel',
947
+ 'beef',
948
+ 'church tower',
949
+ 'dish',
950
+ 'pitch',
951
+ 'cupboard',
952
+ 'thermometer',
953
+ 'dirt field',
954
+ 'fireworks',
955
+ 'minute',
956
+ 'cane',
957
+ 'pajama',
958
+ 'flower garden',
959
+ 'autumn',
960
+ 'trash can',
961
+ 'dachshund',
962
+ 'banana tree',
963
+ 'tray',
964
+ 'moose',
965
+ 'roadway',
966
+ 'carnival',
967
+ 'antenna',
968
+ 'pole',
969
+ 'castle wall',
970
+ 'ram',
971
+ 'cattle',
972
+ 'hay',
973
+ 'cookie',
974
+ 'swimmer',
975
+ 'baseball team',
976
+ 'strait',
977
+ 'hedge',
978
+ 'jet',
979
+ 'fire pit',
980
+ 'octopus',
981
+ 'calf',
982
+ 'cube',
983
+ 'opera',
984
+ 'cardboard box',
985
+ 'tiara',
986
+ 'kitchen sink',
987
+ 'prairie',
988
+ 'bowl',
989
+ 'galaxy',
990
+ 'straw hat',
991
+ 'linen',
992
+ 'ski resort',
993
+ 'stitch',
994
+ 'street lamp',
995
+ 'motorist',
996
+ 'icicle',
997
+ 'stain',
998
+ 'flora',
999
+ 'drain',
1000
+ 'kitchen cabinet',
1001
+ 'decor',
1002
+ 'bouquet',
1003
+ 'pound',
1004
+ 'interior design',
1005
+ 'nail polish',
1006
+ 'figurine',
1007
+ 'tomb',
1008
+ 'disc',
1009
+ 'twist',
1010
+ 'blouse',
1011
+ 'ribbon',
1012
+ 'figure',
1013
+ 'burger',
1014
+ 'cork',
1015
+ 'soccer goalkeeper',
1016
+ 'train bridge',
1017
+ 'drinking water',
1018
+ 'dew',
1019
+ 'baker',
1020
+ 'storm cloud',
1021
+ 'tarmac',
1022
+ 'tv drama',
1023
+ 'sponge',
1024
+ 'magnet',
1025
+ 'sailor',
1026
+ 'entry',
1027
+ 'swan',
1028
+ 'exercise',
1029
+ 'sloth',
1030
+ 'jewel',
1031
+ 'scuba diver',
1032
+ 'bite',
1033
+ 'cat tree',
1034
+ 'tent',
1035
+ 'can',
1036
+ 'tennis match',
1037
+ 'ecosystem',
1038
+ 'picket fence',
1039
+ 'palm',
1040
+ 'train car',
1041
+ 'frying pan',
1042
+ 'rally',
1043
+ 'tablet pc',
1044
+ 'reindeer',
1045
+ 'image',
1046
+ 'wolf',
1047
+ 'chin',
1048
+ 'conservatory',
1049
+ 'flood water',
1050
+ 'cityscape',
1051
+ 'beach sand',
1052
+ 'car park',
1053
+ 'pavement',
1054
+ 'farm field',
1055
+ 'swimming',
1056
+ 'winter storm',
1057
+ 'stem',
1058
+ 'pillow',
1059
+ 'inning',
1060
+ 'gorilla',
1061
+ 'desk',
1062
+ 'avenue',
1063
+ 'fern',
1064
+ 'money',
1065
+ 'pearl',
1066
+ 'train station',
1067
+ 'skillet',
1068
+ 'nap',
1069
+ 'barber',
1070
+ 'library',
1071
+ 'freezer',
1072
+ 'label',
1073
+ 'rainforest',
1074
+ 'parking sign',
1075
+ 'mirror',
1076
+ 'wing',
1077
+ 'noodle',
1078
+ 'press room',
1079
+ 'sculpture',
1080
+ 'tablet',
1081
+ 'viewer',
1082
+ 'prayer',
1083
+ 'mini',
1084
+ 'mechanic',
1085
+ 'laugh',
1086
+ 'rice field',
1087
+ 'hand',
1088
+ 'mustache',
1089
+ 'mountain road',
1090
+ 'catwalk',
1091
+ 'conference',
1092
+ 'cape',
1093
+ 'installation',
1094
+ 'musician',
1095
+ 'stream',
1096
+ 'machine',
1097
+ 'speech',
1098
+ 'crocodile',
1099
+ 'soccer match',
1100
+ 'town square',
1101
+ 'passport',
1102
+ 'post box',
1103
+ 'point',
1104
+ 'stone building',
1105
+ 'motorway',
1106
+ 'mix',
1107
+ 'dentist',
1108
+ 'businessperson',
1109
+ 'happiness',
1110
+ 'boat',
1111
+ 'vineyard',
1112
+ 'treadmill',
1113
+ 'glass wall',
1114
+ 'water droplet',
1115
+ 'coffee mug',
1116
+ 'graduate',
1117
+ 'sunflower',
1118
+ 'parliament',
1119
+ 'shepherd',
1120
+ 'movie',
1121
+ 'wine',
1122
+ 'orchard',
1123
+ 'tulip',
1124
+ 'motherboard',
1125
+ 'cup',
1126
+ 'broom',
1127
+ 'spot',
1128
+ 'drawing',
1129
+ 'polo shirt',
1130
+ 'graduation',
1131
+ 'film producer',
1132
+ 'moonlight',
1133
+ 'glow',
1134
+ 'film format',
1135
+ 't shirt',
1136
+ 'rock face',
1137
+ 'sword',
1138
+ 'clinic',
1139
+ 'festival day',
1140
+ 'meadow',
1141
+ 'staple',
1142
+ 'pupil',
1143
+ 'training ground',
1144
+ 'rider',
1145
+ 'flower',
1146
+ 'foal',
1147
+ 'wharf',
1148
+ 'foot bridge',
1149
+ 'shooting',
1150
+ 'top',
1151
+ 'mast',
1152
+ 'police car',
1153
+ 'robe',
1154
+ 'wedding bouquet',
1155
+ 'stop sign',
1156
+ 'birthday cake',
1157
+ 'glitter',
1158
+ 'butter',
1159
+ 'scooter',
1160
+ 'tundra',
1161
+ 'superhero',
1162
+ 'pocket watch',
1163
+ 'inscription',
1164
+ 'youngster',
1165
+ 'fruit tree',
1166
+ 'movie poster',
1167
+ 'engine',
1168
+ 'foundation',
1169
+ 'motorcyclist',
1170
+ 'take',
1171
+ 'woman',
1172
+ 'antelope',
1173
+ 'country artist',
1174
+ 'road trip',
1175
+ 'typewriter',
1176
+ 'tuxedo',
1177
+ 'brand',
1178
+ 'pine',
1179
+ 'bathroom',
1180
+ 'paradise',
1181
+ 'texture',
1182
+ 'balloon',
1183
+ 'dining table',
1184
+ 'home',
1185
+ 'computer screen',
1186
+ 'actor',
1187
+ 'clip',
1188
+ 'tv tower',
1189
+ 'panorama',
1190
+ 'summit',
1191
+ 'cat',
1192
+ 'plot',
1193
+ 'eagle',
1194
+ 'dancer',
1195
+ 'pup',
1196
+ 'studio shot',
1197
+ 'tear',
1198
+ 'bird bath',
1199
+ 'classroom',
1200
+ 'bookstore',
1201
+ 'city wall',
1202
+ 'tv programme',
1203
+ 'blade',
1204
+ 'easel',
1205
+ 'buttercream',
1206
+ 'sweet',
1207
+ 'designer',
1208
+ 'diamond',
1209
+ 'handshake',
1210
+ 'herb',
1211
+ 'corn field',
1212
+ 'seafront',
1213
+ 'concrete',
1214
+ 'street artist',
1215
+ 'gas',
1216
+ 'stamp',
1217
+ 'window display',
1218
+ 'paper',
1219
+ 'note',
1220
+ 'pint',
1221
+ 'quarry',
1222
+ 'research',
1223
+ 'fixture',
1224
+ 'manager',
1225
+ 'soil',
1226
+ 'leopard',
1227
+ 'board game',
1228
+ 'ladder',
1229
+ 'stop light',
1230
+ 'island',
1231
+ 'ramp',
1232
+ 'football match',
1233
+ 'icing',
1234
+ 'drill',
1235
+ 'currency',
1236
+ 'summer evening',
1237
+ 'topping',
1238
+ 'pyramid',
1239
+ 'pomegranate',
1240
+ 'cell',
1241
+ 'ivy',
1242
+ 'squad',
1243
+ 'scenery',
1244
+ 'computer',
1245
+ 'locomotive',
1246
+ 'surf',
1247
+ 'mascot',
1248
+ 'dune',
1249
+ 'path',
1250
+ 'duck',
1251
+ 'twilight',
1252
+ 'wire',
1253
+ 'bow tie',
1254
+ 'strike',
1255
+ 'cormorant',
1256
+ 'car wash',
1257
+ 'crane',
1258
+ 'market',
1259
+ 'philosopher',
1260
+ 'alarm clock',
1261
+ 'camera',
1262
+ 'birch',
1263
+ 'greeting card',
1264
+ 'plain',
1265
+ 'clay',
1266
+ 'donut',
1267
+ 'lock',
1268
+ 'moth',
1269
+ 'laboratory',
1270
+ 'fan',
1271
+ 'violin',
1272
+ 'jazz fusion artist',
1273
+ 'mountain biker',
1274
+ 'terrain',
1275
+ 'magazine',
1276
+ 'pickup',
1277
+ 'comedy film',
1278
+ 'smartphone',
1279
+ 'film',
1280
+ 'bed',
1281
+ 'microwave oven',
1282
+ 'tournament',
1283
+ 'lawn',
1284
+ 'car window',
1285
+ 'alligator',
1286
+ 'screen',
1287
+ 'jetty',
1288
+ 'shopping bag',
1289
+ 'landscape view',
1290
+ 'cabinetry',
1291
+ 'friendly match',
1292
+ 'thing',
1293
+ 'petal',
1294
+ 'shopping center',
1295
+ 'transport',
1296
+ 'ballet dancer',
1297
+ 'shoreline',
1298
+ 'princess',
1299
+ 'car seat',
1300
+ 'parking meter',
1301
+ 'green',
1302
+ 'vodka',
1303
+ 'band',
1304
+ 'rock',
1305
+ 'costume',
1306
+ 'warning sign',
1307
+ 'strip',
1308
+ 'plaque',
1309
+ 'wheelchair',
1310
+ 'headband',
1311
+ 'ginger',
1312
+ 'dice',
1313
+ 'media',
1314
+ 'hairdresser',
1315
+ 'press',
1316
+ 'living room',
1317
+ 'stove',
1318
+ 'player',
1319
+ 'cherry',
1320
+ 'workshop',
1321
+ 'carving',
1322
+ 'embroidery',
1323
+ 'doodle',
1324
+ 'adventure',
1325
+ 'rugby player',
1326
+ 'monument',
1327
+ 'brush',
1328
+ 'marker',
1329
+ 'loft',
1330
+ 'postcard',
1331
+ 'collage',
1332
+ 'ball',
1333
+ 'professor',
1334
+ 'dresser',
1335
+ 'gig',
1336
+ 'festival',
1337
+ 'blackbird',
1338
+ 'makeup artist',
1339
+ 'video camera',
1340
+ 'sticker',
1341
+ 'peak',
1342
+ 'wildflower',
1343
+ 'santa hat',
1344
+ 'rodeo',
1345
+ 'wedding photographer',
1346
+ 'guy',
1347
+ 'staff',
1348
+ 'waterfall',
1349
+ 'operation',
1350
+ 'defender',
1351
+ 'falcon',
1352
+ 'haze',
1353
+ 'individual',
1354
+ 'gentleman',
1355
+ 'greyhound',
1356
+ 'rocking chair',
1357
+ 'rice',
1358
+ 'garbage',
1359
+ 'platter',
1360
+ 'chocolate',
1361
+ 'splash',
1362
+ 'business suit',
1363
+ 'cheetah',
1364
+ 'valley',
1365
+ 'maze',
1366
+ 'trampoline',
1367
+ 'garland',
1368
+ 'slalom',
1369
+ 'unicorn',
1370
+ 'tree stump',
1371
+ 'painting',
1372
+ 'romance',
1373
+ 'fight',
1374
+ 'alcohol',
1375
+ 'ghost',
1376
+ 'fondant',
1377
+ 'spa',
1378
+ 'shutter',
1379
+ 'death',
1380
+ 'demonstration',
1381
+ 'cotton',
1382
+ 'pier',
1383
+ 'flea market',
1384
+ 'history',
1385
+ 'savannah',
1386
+ 'fist',
1387
+ 'aisle',
1388
+ 'crew',
1389
+ 'jug',
1390
+ 'pose',
1391
+ 'anchor',
1392
+ 'teapot',
1393
+ 'boat house',
1394
+ 'business team',
1395
+ 'tripod',
1396
+ 'bee',
1397
+ 'pebble',
1398
+ 'mattress',
1399
+ 'canvas',
1400
+ 'hallway',
1401
+ 'campaign',
1402
+ 'pod',
1403
+ 'lake district',
1404
+ 'article',
1405
+ 'white',
1406
+ 'sofa',
1407
+ 'honey',
1408
+ 'marathon',
1409
+ 'pancake',
1410
+ 'tourist attraction',
1411
+ 'wedding gown',
1412
+ 'battle',
1413
+ 'shelving',
1414
+ 'sea',
1415
+ 'sheet music',
1416
+ 'pie',
1417
+ 'yarn',
1418
+ 'construction site',
1419
+ 'flyer',
1420
+ 'tie',
1421
+ 'star',
1422
+ 'lettuce',
1423
+ 'martial artist',
1424
+ 'dart',
1425
+ 'straw',
1426
+ 'reflection',
1427
+ 'conference room',
1428
+ 'temperature',
1429
+ 'rugby',
1430
+ 'mosquito',
1431
+ 'physicist',
1432
+ 'rock climber',
1433
+ 'crash',
1434
+ 'backdrop',
1435
+ 'toilet seat',
1436
+ 'sand castle',
1437
+ 'water park',
1438
+ 'toy car',
1439
+ 'waste',
1440
+ 'luxury',
1441
+ 'hangar',
1442
+ 'rv',
1443
+ 'tree trunk',
1444
+ 'board',
1445
+ 'gold',
1446
+ 'project picture',
1447
+ 'cap',
1448
+ 'cottage',
1449
+ 'relief',
1450
+ 'attire',
1451
+ 'microscope',
1452
+ 'battery',
1453
+ 'roll',
1454
+ 'line',
1455
+ 'parking garage',
1456
+ 'crystal',
1457
+ 'broadcasting',
1458
+ 'brick wall',
1459
+ 'lab',
1460
+ 'flooring',
1461
+ 'meeting',
1462
+ '3d cg rendering',
1463
+ 'desktop computer',
1464
+ 'cowboy',
1465
+ 'sailing ship',
1466
+ 'junction',
1467
+ 'hairstyle',
1468
+ 'homework',
1469
+ 'profile',
1470
+ 'model',
1471
+ 'flower pot',
1472
+ 'street light',
1473
+ 'salt lake',
1474
+ 'maple',
1475
+ 'space',
1476
+ 'blizzard',
1477
+ 'throw',
1478
+ 'zebras',
1479
+ 'brochure',
1480
+ 'constellation',
1481
+ 'beak',
1482
+ 'kilt',
1483
+ 'pond',
1484
+ 'blue sky',
1485
+ 'sneaker',
1486
+ 'sand dune',
1487
+ 'morning sun',
1488
+ 'almond',
1489
+ 'grill',
1490
+ 'curl',
1491
+ 'basketball girl game',
1492
+ 'chameleon',
1493
+ 'toilet bowl',
1494
+ 'prince',
1495
+ 'keyboard',
1496
+ 'queen',
1497
+ 'computer monitor',
1498
+ 'writing',
1499
+ 'crown',
1500
+ 'basilica',
1501
+ 'kiss',
1502
+ 'house',
1503
+ 'parking',
1504
+ 'football competition',
1505
+ 'shell',
1506
+ 'sport equipment',
1507
+ 'comedy',
1508
+ 'baboon',
1509
+ 'vendor',
1510
+ 'rise building',
1511
+ 'wrap',
1512
+ 'food truck',
1513
+ 'cat bed',
1514
+ 'rickshaw',
1515
+ 'flare',
1516
+ 'teal',
1517
+ 'nectar',
1518
+ 'eclipse',
1519
+ 'vehicle',
1520
+ 'steam locomotive',
1521
+ 'gorge',
1522
+ 'cow',
1523
+ 'christmas card',
1524
+ 'demonstrator',
1525
+ 'memorial',
1526
+ 'towel',
1527
+ 'jewellery',
1528
+ 'train',
1529
+ 'frisbee',
1530
+ 'baseball game',
1531
+ 'fur',
1532
+ 'afternoon sun',
1533
+ 'community',
1534
+ 'sparkler',
1535
+ 'bandage',
1536
+ 'firework',
1537
+ 'dollar',
1538
+ 'pasture',
1539
+ 'video',
1540
+ 'bus',
1541
+ 'tree house',
1542
+ 'seashore',
1543
+ 'field',
1544
+ 'hamburger',
1545
+ 'souvenir',
1546
+ 'hedgehog',
1547
+ 'worm',
1548
+ 'pine cone',
1549
+ 'osprey',
1550
+ 'dinosaur',
1551
+ 'vegetable',
1552
+ 'junk',
1553
+ 'poster',
1554
+ 'army',
1555
+ 'winger',
1556
+ 'bundle',
1557
+ 'stage',
1558
+ 'growth',
1559
+ 'wedding party',
1560
+ 'service',
1561
+ 'blanket',
1562
+ 'ruler',
1563
+ 'eye',
1564
+ 'credit card',
1565
+ 'castle',
1566
+ 'diner',
1567
+ 'hut',
1568
+ 'elk',
1569
+ 'hard rock artist',
1570
+ 'nun',
1571
+ 'dog breed',
1572
+ 'nest',
1573
+ 'drama film',
1574
+ 'number icon',
1575
+ 'water tank',
1576
+ 'giraffe',
1577
+ 'altar',
1578
+ 'pavilion',
1579
+ 'tv personality',
1580
+ 'suv',
1581
+ 'street vendor',
1582
+ 'street sign',
1583
+ 'ditch',
1584
+ 'debris',
1585
+ 'foam',
1586
+ 'takeoff',
1587
+ 'spice',
1588
+ 'mountain lake',
1589
+ 'tea',
1590
+ 'orchestra',
1591
+ 'spacecraft',
1592
+ 'counter',
1593
+ 'abbey',
1594
+ 'mountain',
1595
+ 'hydrangea',
1596
+ 'racer',
1597
+ 'orange tree',
1598
+ 'tide',
1599
+ 'cowboy hat',
1600
+ 'rapid',
1601
+ 'town',
1602
+ 'wild',
1603
+ 'herd',
1604
+ 'vein',
1605
+ 'driveway',
1606
+ 'jar',
1607
+ 'bark',
1608
+ 'illustration',
1609
+ 'horror film',
1610
+ 'corn',
1611
+ 'stroller',
1612
+ 'industry',
1613
+ 'mountain stream',
1614
+ 'gym',
1615
+ 'neckline',
1616
+ 'pan',
1617
+ 'client',
1618
+ 'spectator',
1619
+ 'eggplant',
1620
+ 'camper',
1621
+ 'fawn',
1622
+ 'hoodie',
1623
+ 'meat',
1624
+ 'lemonade',
1625
+ 'food market',
1626
+ 'slum',
1627
+ 'comic book character',
1628
+ 'flower market',
1629
+ 'love',
1630
+ 'palace',
1631
+ 'gun',
1632
+ 'heel',
1633
+ 'shopping street',
1634
+ 'shooting basketball guard',
1635
+ 'family photo',
1636
+ 'rooftop',
1637
+ 'laundry basket',
1638
+ 'airport runway',
1639
+ 'horn',
1640
+ 'face mask',
1641
+ 'flight',
1642
+ 'appetizer',
1643
+ 'violet',
1644
+ 'country lane',
1645
+ 'cement',
1646
+ 'instrument',
1647
+ 'tv actor',
1648
+ 'spark',
1649
+ 'celebrity',
1650
+ 'award',
1651
+ 'country house',
1652
+ 'standing',
1653
+ 'auction',
1654
+ 'date',
1655
+ 'engagement',
1656
+ 'puck',
1657
+ 'advertisement',
1658
+ 'chair',
1659
+ 'zebra',
1660
+ 'driftwood',
1661
+ 'bumblebee',
1662
+ 'maple leaf',
1663
+ 'bonnet',
1664
+ 'orange',
1665
+ 'water tower',
1666
+ 'door',
1667
+ 'singer',
1668
+ 'floor plan',
1669
+ 'discussion',
1670
+ 'theatre',
1671
+ 'pilgrim',
1672
+ 'mug',
1673
+ 'branch',
1674
+ 'window sill',
1675
+ 'baseball pitcher',
1676
+ 'bakery',
1677
+ 'lollipop',
1678
+ 'basketball player',
1679
+ 'toilet paper',
1680
+ 'chalkboard',
1681
+ 'cabin',
1682
+ 'sign',
1683
+ 'night sky',
1684
+ 'cannon',
1685
+ 'fishing net',
1686
+ 'submarine',
1687
+ 'suit',
1688
+ 'fur coat',
1689
+ 'wine bottle',
1690
+ 'folder',
1691
+ 'street art',
1692
+ 'suspension bridge',
1693
+ 'evening sky',
1694
+ 'billboard',
1695
+ 'postage stamp',
1696
+ 'newspaper',
1697
+ 'transportation',
1698
+ 'surgeon',
1699
+ 'light',
1700
+ 'park',
1701
+ 'horizon',
1702
+ 'road',
1703
+ 'sand bar',
1704
+ 'trumpet',
1705
+ 'lounge',
1706
+ 'cloud forest',
1707
+ 'birthday celebration',
1708
+ 'balcony',
1709
+ 'anime',
1710
+ 'beehive',
1711
+ 'umbrella',
1712
+ 'goldfish',
1713
+ 'baseball cap',
1714
+ 'waterhole',
1715
+ 'ceiling',
1716
+ 'carousel',
1717
+ 'backpack',
1718
+ 'plant pot',
1719
+ 'atmosphere',
1720
+ 'sunflower field',
1721
+ 'spire',
1722
+ 'vision',
1723
+ 'woodpecker',
1724
+ 'chip',
1725
+ 'pool table',
1726
+ 'lotus flower',
1727
+ 'cone',
1728
+ 'humpback whale',
1729
+ 'reservoir',
1730
+ 'hunt',
1731
+ 'piano',
1732
+ 'plate',
1733
+ 'dining area',
1734
+ 'luggage',
1735
+ 'skier',
1736
+ 'dance floor',
1737
+ 'crow',
1738
+ 'stair',
1739
+ 'overpass',
1740
+ 'opera house',
1741
+ 'bear',
1742
+ 'jazz artist',
1743
+ 'water',
1744
+ 'vessel',
1745
+ 'cast',
1746
+ 'yard',
1747
+ 'cathedral',
1748
+ 'basketball hoop',
1749
+ 'graveyard',
1750
+ 'sound',
1751
+ 'berry',
1752
+ 'onlooker',
1753
+ 'fauna',
1754
+ 'birch tree',
1755
+ 'retail',
1756
+ 'hill',
1757
+ 'skeleton',
1758
+ 'journalist',
1759
+ 'frost',
1760
+ 'basket',
1761
+ 'nail',
1762
+ 'dusk',
1763
+ 'trash',
1764
+ 'dawn',
1765
+ 'clover',
1766
+ 'hen',
1767
+ 'volcano',
1768
+ 'basketball coach',
1769
+ 'home decor',
1770
+ 'charge',
1771
+ 'haircut',
1772
+ 'sense',
1773
+ 'university',
1774
+ 'lizard',
1775
+ 'daisy',
1776
+ 'tablet computer',
1777
+ 'grass field',
1778
+ 'prison',
1779
+ 'metal artist',
1780
+ 'bathroom mirror',
1781
+ 'window frame',
1782
+ 'chest',
1783
+ 'flavor',
1784
+ 'pop country artist',
1785
+ 'market square',
1786
+ 'monkey',
1787
+ 'blog',
1788
+ 'deer',
1789
+ 'speech bubble',
1790
+ 'dog',
1791
+ 'independence day',
1792
+ 'girl',
1793
+ 'boy',
1794
+ 'tartan',
1795
+ 'furniture',
1796
+ 'appliance',
1797
+ 'office window',
1798
+ 'fish boat',
1799
+ 'sand box',
1800
+ 'tv sitcom',
1801
+ 'drama',
1802
+ 'sleigh',
1803
+ 'depression',
1804
+ 'paper towel',
1805
+ 'baseball',
1806
+ 'protestor',
1807
+ 'grape',
1808
+ 'wedding cake',
1809
+ 'invitation',
1810
+ 'accessory',
1811
+ 'pick',
1812
+ 'grandparent',
1813
+ 'racket',
1814
+ 'tea plantation',
1815
+ 'outdoors',
1816
+ 'egg',
1817
+ 'glass bowl',
1818
+ 'sun',
1819
+ 'organization',
1820
+ 'lion',
1821
+ 'panel',
1822
+ 'station',
1823
+ 'wallpaper',
1824
+ 'helicopter',
1825
+ 'salt',
1826
+ 'vanity',
1827
+ 'patio',
1828
+ 'lunch',
1829
+ 'street performer',
1830
+ 'mountain range',
1831
+ 'soup',
1832
+ 'bacon',
1833
+ 'power station',
1834
+ 'cantilever bridge',
1835
+ 'hummingbird',
1836
+ 'shirt',
1837
+ 'rope',
1838
+ 'hip',
1839
+ 'chalk',
1840
+ 'pendant',
1841
+ 'choir',
1842
+ 'tv',
1843
+ 'lichen',
1844
+ 'railway bridge',
1845
+ 'art gallery',
1846
+ 'bartender',
1847
+ 'wagon',
1848
+ 'baby elephant',
1849
+ 'accordion',
1850
+ 'horseshoe',
1851
+ 'building site',
1852
+ 'clutch',
1853
+ 'harvest',
1854
+ 'savanna',
1855
+ 'geranium',
1856
+ 'business woman',
1857
+ 'paddock',
1858
+ 'patch',
1859
+ 'beech tree',
1860
+ 'war',
1861
+ 'suburbs',
1862
+ 'hospital bed',
1863
+ 'motorcycle racer',
1864
+ 'moss',
1865
+ 'gravel',
1866
+ 'government agency',
1867
+ 'dollar bill',
1868
+ 'father',
1869
+ 'fjord',
1870
+ 'concert',
1871
+ 'nut',
1872
+ 'wedding photography',
1873
+ 'finish line',
1874
+ 'home plate',
1875
+ 'food',
1876
+ 'nose',
1877
+ 'thumb',
1878
+ 'village',
1879
+ 'dining room table',
1880
+ 'bumper',
1881
+ 'monster',
1882
+ 'blackberry',
1883
+ 'lime',
1884
+ 'conflict',
1885
+ 'gala',
1886
+ 'wallet',
1887
+ 'wrist',
1888
+ 'hug',
1889
+ 'mermaid',
1890
+ 'lava',
1891
+ 'lawyer',
1892
+ 'folk rock artist',
1893
+ 'arena',
1894
+ 'onion',
1895
+ 'toothbrush',
1896
+ 'fashion',
1897
+ 'perfume',
1898
+ 'flip',
1899
+ 'triangle',
1900
+ 'woodland',
1901
+ 'mail',
1902
+ 'grasshopper',
1903
+ 'studio',
1904
+ 'wood floor',
1905
+ 'den',
1906
+ 'racquet',
1907
+ 'cello',
1908
+ 'lemur',
1909
+ 'astronaut',
1910
+ 'glass table',
1911
+ 'blood',
1912
+ 'dvd',
1913
+ 'planter',
1914
+ 'silver',
1915
+ 'leash',
1916
+ 'master bedroom',
1917
+ 'forest',
1918
+ 'batter',
1919
+ 'shoe',
1920
+ 'engraving',
1921
+ 'opening',
1922
+ 'product',
1923
+ 'toe',
1924
+ 'cocktail',
1925
+ 'mallard duck',
1926
+ 'bike ride',
1927
+ 'oasis',
1928
+ 'wedding ring',
1929
+ 'cinematographer',
1930
+ 'holly',
1931
+ 'autograph',
1932
+ 'fence',
1933
+ 'ice cube',
1934
+ 'cove',
1935
+ 'pineapple',
1936
+ 'aurora',
1937
+ 'glass bead',
1938
+ 'produce',
1939
+ 'apartment building',
1940
+ 'cob',
1941
+ 'miniature',
1942
+ 'cockpit',
1943
+ 'flashlight',
1944
+ 'frog',
1945
+ 'sheep',
1946
+ 'groom',
1947
+ 'steel',
1948
+ 'watermelon',
1949
+ 'clip art',
1950
+ 'paper plate',
1951
+ 'ostrich',
1952
+ 'contour',
1953
+ 'mural',
1954
+ 'cub',
1955
+ 'paisley bandanna',
1956
+ 'winery',
1957
+ 'turn',
1958
+ 'handle',
1959
+ 'satellite',
1960
+ 'post',
1961
+ 'pork',
1962
+ 'child',
1963
+ 'asphalt',
1964
+ 'grocery store',
1965
+ 'vulture',
1966
+ 'trolley',
1967
+ 'nightclub',
1968
+ 'brick',
1969
+ 'trailer',
1970
+ 'compass',
1971
+ 'cereal',
1972
+ 'cafe',
1973
+ 'cartoon character',
1974
+ 'sugar',
1975
+ 'fiction book',
1976
+ 'glass floor',
1977
+ 'umpire',
1978
+ 'guitar',
1979
+ 'hamster',
1980
+ 'protester',
1981
+ 'airplane',
1982
+ 'garment',
1983
+ 'blazer',
1984
+ 'railway line',
1985
+ 'wedding',
1986
+ 'shoe box',
1987
+ 'parking lot',
1988
+ 'construction',
1989
+ 'graduation ceremony',
1990
+ 'tram',
1991
+ 'telescope',
1992
+ 'copper',
1993
+ 'pain',
1994
+ 'autumn forest',
1995
+ 'guest house',
1996
+ 'partner',
1997
+ 'crayon',
1998
+ 'dip',
1999
+ 'boot',
2000
+ 'corridor',
2001
+ 'computer keyboard',
2002
+ 'hockey player',
2003
+ 'chicken coop',
2004
+ 'bus station',
2005
+ 'gathering',
2006
+ 'ankle',
2007
+ 'bunk bed',
2008
+ 'wood table',
2009
+ 'football coach',
2010
+ 'monarch',
2011
+ 'pharmacy',
2012
+ 'legging',
2013
+ 'mannequin',
2014
+ 'female',
2015
+ 'train track',
2016
+ 'stack',
2017
+ 'canopy',
2018
+ 'design element',
2019
+ 'grandmother',
2020
+ 'symbol',
2021
+ 'beach hut',
2022
+ 'zucchini',
2023
+ 'bomb',
2024
+ 'businessman',
2025
+ 'skyscraper',
2026
+ 'tongue',
2027
+ 'case',
2028
+ 'sparkle',
2029
+ 'highland',
2030
+ 'ballroom',
2031
+ 'prom',
2032
+ 'estate',
2033
+ 'customer',
2034
+ 'archipelago',
2035
+ 'cheese',
2036
+ 'debate',
2037
+ 'carriage',
2038
+ 'bulldozer',
2039
+ 'pumpkin',
2040
+ 'sitting room',
2041
+ 'gas station',
2042
+ 'wedding reception',
2043
+ 'camp',
2044
+ 'dog bed',
2045
+ 'tower',
2046
+ 'property',
2047
+ 'river bed',
2048
+ 'pop latin artist',
2049
+ 'fridge',
2050
+ 'wine glass',
2051
+ 'coast',
2052
+ 'beer',
2053
+ 'tow truck',
2054
+ 'fire truck',
2055
+ 'mountain bike',
2056
+ 'thigh',
2057
+ 'heron',
2058
+ 'boat ride',
2059
+ 'gondola',
2060
+ 'turquoise',
2061
+ 'lake',
2062
+ 'llama',
2063
+ 'kitty',
2064
+ 'tin',
2065
+ 'waiting room',
2066
+ 'coffee cup',
2067
+ 'socialite',
2068
+ 'guard',
2069
+ 'tap',
2070
+ 'waterway',
2071
+ 'forehead',
2072
+ 'list',
2073
+ 'erosion',
2074
+ 'box',
2075
+ 'sea lion',
2076
+ 'pollen',
2077
+ 'dam',
2078
+ 'wasp',
2079
+ 'salon',
2080
+ 'tennis tournament',
2081
+ 'flower box',
2082
+ 'aquarium',
2083
+ 'rain cloud',
2084
+ 'clothing store',
2085
+ 'lead singer',
2086
+ 'cupcake',
2087
+ 'tortoise',
2088
+ 'lettering',
2089
+ 'sport facility',
2090
+ 'dance',
2091
+ 'dog house',
2092
+ 'nature',
2093
+ 'football',
2094
+ 'rooster',
2095
+ 'footballer',
2096
+ 'railway track',
2097
+ 'crowd',
2098
+ 'fishing rod',
2099
+ 'silhouette',
2100
+ 'wind turbine',
2101
+ 'sari',
2102
+ 'bus window',
2103
+ 'cloud',
2104
+ 'charity',
2105
+ 'medal',
2106
+ 'yoga',
2107
+ 'event',
2108
+ 'veil',
2109
+ 'fashion menswear milan week',
2110
+ 'news',
2111
+ 'knife',
2112
+ 'print',
2113
+ 'screen tv',
2114
+ 'walnut',
2115
+ 'fungus',
2116
+ 'ice cream',
2117
+ 'computer mouse',
2118
+ 'play',
2119
+ 'tribe',
2120
+ 'picture',
2121
+ 'video game',
2122
+ 'business card',
2123
+ 'music festival',
2124
+ 'rack',
2125
+ 'envelope',
2126
+ 'shower',
2127
+ 'dirt road',
2128
+ 'mine',
2129
+ 'oyster',
2130
+ 'monarch butterfly',
2131
+ 'dude',
2132
+ 'fruit salad',
2133
+ 'podium',
2134
+ 'fork',
2135
+ 'lace',
2136
+ 'test match',
2137
+ 'boulder',
2138
+ 'cricket player',
2139
+ 'staircase',
2140
+ 'peninsula',
2141
+ 'shopping',
2142
+ 'popcorn',
2143
+ 'oak',
2144
+ 'market stall',
2145
+ 'pine tree',
2146
+ 'mountaineer',
2147
+ 'student',
2148
+ 'closet',
2149
+ 'hood',
2150
+ 'handstand',
2151
+ 'centerpiece',
2152
+ 'insect',
2153
+ 'patient',
2154
+ 'makeover',
2155
+ 'tennis player',
2156
+ 'sheet',
2157
+ 'park bench',
2158
+ 'apple',
2159
+ 'organism',
2160
+ 'hook',
2161
+ 'turkey',
2162
+ 'tangerine',
2163
+ 'sibling',
2164
+ 'shopping mall',
2165
+ 'bird',
2166
+ 'scarf',
2167
+ 'smoothie',
2168
+ 'net',
2169
+ 'grass',
2170
+ 'napkin',
2171
+ 'ray',
2172
+ 'eyebrow',
2173
+ 'laptop keyboard',
2174
+ 'motorbike',
2175
+ 'woman hand',
2176
+ 'oven',
2177
+ 'book cover',
2178
+ 'easter egg',
2179
+ 'microwave',
2180
+ 'sand',
2181
+ 'snapshot',
2182
+ 'soccer ball',
2183
+ 'makeup',
2184
+ 'knight',
2185
+ 'bowling ball',
2186
+ 'shower curtain',
2187
+ 'flame',
2188
+ 'lightning',
2189
+ 'running',
2190
+ 'power plant',
2191
+ 'crib',
2192
+ 'cartoon',
2193
+ 'moat',
2194
+ 'fashion girl',
2195
+ 'wedding invitation',
2196
+ 'bottle',
2197
+ 'cliff',
2198
+ 'monastery',
2199
+ 'file photo',
2200
+ 'apartment',
2201
+ 'casino',
2202
+ 'cream',
2203
+ 'sweatshirt',
2204
+ 'storm',
2205
+ 'cruise',
2206
+ 'teddy bear',
2207
+ 'shovel',
2208
+ 'wind farm',
2209
+ 'writer',
2210
+ 'dock',
2211
+ 'professional',
2212
+ 'hotel room',
2213
+ 'job',
2214
+ 'monitor',
2215
+ 'donkey',
2216
+ 'pass',
2217
+ 'interview',
2218
+ 'duchess',
2219
+ 'mark',
2220
+ 'plank',
2221
+ 'beard',
2222
+ 'zombie',
2223
+ 'trio',
2224
+ 'channel',
2225
+ 'cricket team',
2226
+ 'windmill',
2227
+ 'vest',
2228
+ 'diagram',
2229
+ 'cable',
2230
+ 'winter scene',
2231
+ 'golden gate bridge',
2232
+ 'buffalo',
2233
+ 'studio portrait',
2234
+ 'pagoda',
2235
+ 'whiskey',
2236
+ 'freight train',
2237
+ 'kite',
2238
+ 'future',
2239
+ 'steam train',
2240
+ 'phone box',
2241
+ 'headset',
2242
+ 'wood',
2243
+ 'snowboarder',
2244
+ 'paper bag',
2245
+ 'slide',
2246
+ 'grapefruit',
2247
+ 'seating',
2248
+ 'morning',
2249
+ 'bronze sculpture',
2250
+ 'theatre actor',
2251
+ 'stump',
2252
+ 'jean',
2253
+ 'landmark',
2254
+ 'jam',
2255
+ 'waist',
2256
+ 'watercolor',
2257
+ 'hammock',
2258
+ 'light fixture',
2259
+ 'ice',
2260
+ 'basin',
2261
+ 'beverage',
2262
+ 'shelter',
2263
+ 'premiere',
2264
+ 'mound',
2265
+ 'ear',
2266
+ 'bronze',
2267
+ 'sunlight',
2268
+ 'street',
2269
+ 'energy',
2270
+ 'barn door',
2271
+ 'hike',
2272
+ 'fleet',
2273
+ 'claw',
2274
+ 'beach',
2275
+ 'pepperoni',
2276
+ 'bin',
2277
+ 'trainer',
2278
+ 'buffet',
2279
+ 'archive',
2280
+ 'toddler',
2281
+ 'referee',
2282
+ 'bay window',
2283
+ 'dove',
2284
+ 'production company',
2285
+ 'evening light',
2286
+ 'gate',
2287
+ 'farm',
2288
+ 'reed',
2289
+ 'fruit stand',
2290
+ 'explorer',
2291
+ 'snow storm',
2292
+ 'throw pillow',
2293
+ 'button',
2294
+ 'display case',
2295
+ 'bookcase',
2296
+ 'lead',
2297
+ 'lipstick',
2298
+ 'basketball court',
2299
+ 'cargo',
2300
+ 'ensemble',
2301
+ 'pope',
2302
+ 'clock tower',
2303
+ 'teen',
2304
+ 'speaker',
2305
+ 'rat',
2306
+ 'laptop',
2307
+ 'ski',
2308
+ 'mess',
2309
+ 'stadium',
2310
+ 'ferry boat',
2311
+ 'bunny',
2312
+ 'waterfront',
2313
+ 'downtown',
2314
+ 'sink',
2315
+ 'press conference',
2316
+ 'dinner',
2317
+ 'condiment',
2318
+ 'thread',
2319
+ 'audience',
2320
+ 'grid',
2321
+ 'car',
2322
+ 'plastic',
2323
+ 'people',
2324
+ 'barbecue',
2325
+ 'pigeon',
2326
+ 'urinal',
2327
+ 'seagull',
2328
+ 'volunteer',
2329
+ 'hockey',
2330
+ 'fir tree',
2331
+ 'pollution',
2332
+ 'trial',
2333
+ 'collar',
2334
+ 'area',
2335
+ 'meeting room',
2336
+ 'circus',
2337
+ 'yogurt',
2338
+ 'orangutan',
2339
+ 'viaduct',
2340
+ 'comedian',
2341
+ 'drone',
2342
+ 'scissor',
2343
+ 'pop rock artist',
2344
+ 'biscuit',
2345
+ 'panda',
2346
+ 'water feature',
2347
+ 'air balloon',
2348
+ 'remote control',
2349
+ 'watercolor painting',
2350
+ 'show',
2351
+ 'walk',
2352
+ 'post office',
2353
+ 'bike path',
2354
+ 'rap gangsta artist',
2355
+ 'microphone',
2356
+ 'crack',
2357
+ 'sunset sky',
2358
+ 'glass',
2359
+ 'tv show',
2360
+ 'cartoon style',
2361
+ 'stripe',
2362
+ 'foyer',
2363
+ 'signal',
2364
+ 'calligraphy',
2365
+ 'bulb',
2366
+ 'gardener',
2367
+ 'coffee bean',
2368
+ 'spider',
2369
+ 'tapestry',
2370
+ 'city skyline',
2371
+ 'necklace',
2372
+ 'kitten',
2373
+ 'traveler',
2374
+ 'veteran',
2375
+ 'frosting',
2376
+ 'fry',
2377
+ 'tennis court',
2378
+ 'tank top',
2379
+ 'butterfly house',
2380
+ 'mist',
2381
+ 'drummer',
2382
+ 'water level',
2383
+ 'scale',
2384
+ 'baseball glove',
2385
+ 'music video performer',
2386
+ 'champagne',
2387
+ 'camping',
2388
+ 'clothing',
2389
+ 'water drop',
2390
+ 'telephone box',
2391
+ 'pen',
2392
+ 'morning mist',
2393
+ 'fire engine',
2394
+ 'porch',
2395
+ 'opening ceremony',
2396
+ 'style',
2397
+ 'palm tree',
2398
+ 'fashion show',
2399
+ 'universe',
2400
+ 'scratch',
2401
+ 'axe',
2402
+ 'ottoman',
2403
+ 'explosion',
2404
+ 'rib',
2405
+ 'boutique',
2406
+ 'game',
2407
+ 'cucumber',
2408
+ 'fruit',
2409
+ 'stone bridge',
2410
+ 'nature reserve',
2411
+ 'track',
2412
+ 'train window',
2413
+ 'punch',
2414
+ 'telephone pole',
2415
+ 'velvet',
2416
+ 'sauce',
2417
+ 'moon',
2418
+ 'contrast',
2419
+ 'flamingo',
2420
+ 'bat',
2421
+ 'vending machine',
2422
+ 'ship',
2423
+ 'equestrian',
2424
+ 'shade',
2425
+ 'comforter',
2426
+ 'pallet',
2427
+ 'sparrow',
2428
+ 'wii',
2429
+ 'glaze',
2430
+ 'grocery',
2431
+ 'steeple',
2432
+ 'soccer player',
2433
+ 'contract',
2434
+ 'advertising',
2435
+ 'runner',
2436
+ 'chimpanzee',
2437
+ 'world',
2438
+ 'seat',
2439
+ 'project',
2440
+ 'chihuahua',
2441
+ 'bubble',
2442
+ 'willow',
2443
+ 'pedestal',
2444
+ 'soul hip hop artist',
2445
+ 'curb',
2446
+ 'drawer',
2447
+ 'leaf',
2448
+ 'banner',
2449
+ 'launch party',
2450
+ 'coach',
2451
+ 'government',
2452
+ 'snowball',
2453
+ 'toy',
2454
+ 'portrait',
2455
+ 'doctor',
2456
+ 'whiteboard',
2457
+ 'electronic',
2458
+ 'tiger',
2459
+ 'graffiti',
2460
+ 'column',
2461
+ 'nightstand',
2462
+ 'whistle',
2463
+ 'maxi dress',
2464
+ 'bench',
2465
+ 'wetsuit',
2466
+ 'bird feeder',
2467
+ 'football game',
2468
+ 'basketball',
2469
+ 'class',
2470
+ 'bathroom door',
2471
+ 'store window',
2472
+ 'text message',
2473
+ 'wreath',
2474
+ 'street view',
2475
+ 'binocular',
2476
+ 'pet',
2477
+ 'facade',
2478
+ 'drought',
2479
+ 'lemon',
2480
+ 'new year',
2481
+ 'night view',
2482
+ 'airplane window',
2483
+ 'specie',
2484
+ 'rule',
2485
+ 'jaw',
2486
+ 'wheat field',
2487
+ 'diet',
2488
+ 'pop artist',
2489
+ 'habitat',
2490
+ 'screenshot',
2491
+ 'scoreboard',
2492
+ 'shore',
2493
+ 'mane',
2494
+ 'quilt',
2495
+ 'ski lift',
2496
+ 'orchid',
2497
+ 'turban',
2498
+ 'christmas',
2499
+ 'airport',
2500
+ 'marina',
2501
+ 'glass door',
2502
+ 'glass bottle',
2503
+ 'restaurant',
2504
+ 'conductor',
2505
+ 'logo',
2506
+ 'sleep',
2507
+ 'tape',
2508
+ 'tomato',
2509
+ 'river bank',
2510
+ 'lilac',
2511
+ 'tooth',
2512
+ 'training',
2513
+ 'pottery',
2514
+ 'shop',
2515
+ 'steam engine',
2516
+ 'mason jar',
2517
+ 'base',
2518
+ 'procession',
2519
+ 'border',
2520
+ 'shoot',
2521
+ 'footprint',
2522
+ 'hotdog',
2523
+ 'bull',
2524
+ 'stocking',
2525
+ 'recreation',
2526
+ 'automobile model',
2527
+ 'design',
2528
+ 'country pop artist',
2529
+ 'river',
2530
+ 'retriever',
2531
+ 'department store',
2532
+ 'auditorium',
2533
+ 'sport car',
2534
+ 'supermarket',
2535
+ 'belt',
2536
+ 'cricket',
2537
+ 'window box',
2538
+ 'dress shirt',
2539
+ 'letter',
2540
+ 'residence',
2541
+ 'megaphone',
2542
+ 'pant',
2543
+ 'wildfire',
2544
+ 'bird nest',
2545
+ 'crab',
2546
+ 'swimsuit',
2547
+ 'candle',
2548
+ 'funeral',
2549
+ 'mill',
2550
+ 'national park',
2551
+ 'plant',
2552
+ 'cop',
2553
+ 'power line',
2554
+ 'perch',
2555
+ 'blue',
2556
+ 'finger',
2557
+ 'ferris wheel',
2558
+ 'globe',
2559
+ 'skateboard',
2560
+ 'helmet',
2561
+ 'movie theater',
2562
+ 'uniform',
2563
+ 'hammer',
2564
+ 'material',
2565
+ 'kid',
2566
+ 'well',
2567
+ 'butterfly',
2568
+ 'sideline',
2569
+ 'fashion fall show',
2570
+ 'planet earth',
2571
+ 'lift',
2572
+ 'male',
2573
+ 'sauna',
2574
+ 'gray',
2575
+ 'flour',
2576
+ 'sand sculpture',
2577
+ 'program',
2578
+ 'cabinet',
2579
+ 'infant',
2580
+ 'wheel',
2581
+ 'aircraft model',
2582
+ 'dough',
2583
+ 'garlic',
2584
+ 'skate',
2585
+ 'arrow',
2586
+ 'wrapping paper',
2587
+ 'ripple',
2588
+ 'lamp',
2589
+ 'iron',
2590
+ 'banknote',
2591
+ 'beaver',
2592
+ 'ferry',
2593
+ 'courtyard',
2594
+ 'bassist',
2595
+ 'countryside',
2596
+ 'steak',
2597
+ 'comfort',
2598
+ 'boxer',
2599
+ 'laundry room',
2600
+ 'campsite',
2601
+ 'brick building',
2602
+ 'golf',
2603
+ 'subway',
2604
+ 'headphone',
2605
+ 'fort',
2606
+ 'handbag',
2607
+ 'drum',
2608
+ 'flood',
2609
+ 'saddle',
2610
+ 'bass',
2611
+ 'labyrinth',
2612
+ 'needle',
2613
+ 'sun ray',
2614
+ 'app',
2615
+ 'menu',
2616
+ 'president',
2617
+ 'cardigan',
2618
+ 'dandelion',
2619
+ 'wetland',
2620
+ 'ice hockey player',
2621
+ 'number',
2622
+ 'city hall',
2623
+ 'fishing',
2624
+ 'portrait session',
2625
+ 'pug',
2626
+ 'key',
2627
+ 'art print',
2628
+ 'minister',
2629
+ 'hurdle',
2630
+ 'emergency',
2631
+ 'painting artist',
2632
+ 'flag pole',
2633
+ 'evening',
2634
+ 'purse',
2635
+ 'recipe',
2636
+ 'golf ball',
2637
+ 'coloring book',
2638
+ 'mountain peak',
2639
+ 'senior',
2640
+ 'holiday',
2641
+ 'bud',
2642
+ 'cousin',
2643
+ 'pantry',
2644
+ 'lap',
2645
+ 'skin',
2646
+ 'flag',
2647
+ 'tissue paper',
2648
+ 'ridge',
2649
+ 'wire fence',
2650
+ 'surfer',
2651
+ 'climber',
2652
+ 'photograph',
2653
+ 'sewing machine',
2654
+ 'cooler',
2655
+ 'actress',
2656
+ 'apple tree',
2657
+ 'cancer',
2658
+ 'starfish',
2659
+ 'automobile make',
2660
+ 'dumbbell',
2661
+ 'brace',
2662
+ 'tunnel',
2663
+ 'window',
2664
+ 'paint artist',
2665
+ 'composition',
2666
+ 'school student',
2667
+ 'condo',
2668
+ 'convertible',
2669
+ 'cushion',
2670
+ 'selfie',
2671
+ 'territory',
2672
+ 'guide',
2673
+ 'tree',
2674
+ 'court',
2675
+ 'shrimp',
2676
+ 'stone house',
2677
+ 'dress',
2678
+ 'eyelash',
2679
+ 'juice',
2680
+ 'broccoli',
2681
+ 'chain',
2682
+ 'tourism',
2683
+ 'mountain top',
2684
+ 'concept car',
2685
+ 'film premiere',
2686
+ 'light bulb',
2687
+ 'cafeteria',
2688
+ 'badge',
2689
+ 'flower bed',
2690
+ 'theater',
2691
+ 'root',
2692
+ 'racecar driver',
2693
+ 'basketball boy game',
2694
+ 'glove',
2695
+ 'skyline',
2696
+ 'wall',
2697
+ 'glacier',
2698
+ 'airport terminal',
2699
+ 'bug',
2700
+ 'trim',
2701
+ 'railway station',
2702
+ 'briefcase',
2703
+ 'flat',
2704
+ 'fountain',
2705
+ 'person',
2706
+ 'lane',
2707
+ 'asparagus',
2708
+ 'art',
2709
+ 'lantern',
2710
+ 'dishwasher',
2711
+ 'director',
2712
+ 'snake',
2713
+ 'lecture',
2714
+ 'game controller',
2715
+ 'tree branch',
2716
+ 'pub',
2717
+ 'bathing suit',
2718
+ 'queue',
2719
+ 'belly',
2720
+ 'poppy',
2721
+ 'bow',
2722
+ 'pitcher',
2723
+ 'ice cream cone',
2724
+ 'cave',
2725
+ 'candy',
2726
+ 'road bridge',
2727
+ 'host',
2728
+ 'traffic jam',
2729
+ 'earring',
2730
+ 'file',
2731
+ 'foot',
2732
+ 'watermark overlay stamp',
2733
+ 'mailbox',
2734
+ 'supercar',
2735
+ 'railing',
2736
+ 'bedroom',
2737
+ 'seafood',
2738
+ 'waffle',
2739
+ 'bronze statue',
2740
+ 'plan',
2741
+ 'flow',
2742
+ 'marble',
2743
+ 'basketball game',
2744
+ 'automobile',
2745
+ 'scene',
2746
+ 'cypress tree',
2747
+ 'soldier',
2748
+ 'skateboarder',
2749
+ 'glass building',
2750
+ 'cherry tree',
2751
+ 'pump',
2752
+ 'grain',
2753
+ 'wildebeest',
2754
+ 'loop',
2755
+ 'frame',
2756
+ 'bathtub',
2757
+ 'saxophone',
2758
+ 'diver',
2759
+ 'stalk',
2760
+ 'lily',
2761
+ 'bead',
2762
+ 'alley',
2763
+ 'flock',
2764
+ 'family room',
2765
+ 'manufacturing',
2766
+ 'pointer',
2767
+ 'worker',
2768
+ 'navy',
2769
+ 'potato',
2770
+ 'teacher',
2771
+ 'photography',
2772
+ 'dolly',
2773
+ 'boardwalk',
2774
+ 'water fountain',
2775
+ 'athlete',
2776
+ 'side dish',
2777
+ 'bay',
2778
+ 'ice hockey',
2779
+ 'phone',
2780
+ 'hero',
2781
+ 'face',
2782
+ 'gold medal',
2783
+ 'blind',
2784
+ 'swamp',
2785
+ 'researcher',
2786
+ 'swim',
2787
+ 'meatball',
2788
+ 'iguana',
2789
+ 'leather jacket',
2790
+ 'jellyfish',
2791
+ 'site',
2792
+ 'smoke',
2793
+ 'traffic signal',
2794
+ 'melon',
2795
+ 'beetle',
2796
+ 'calculator',
2797
+ 'skirt',
2798
+ 'plantation',
2799
+ 'sculptor',
2800
+ 'barrier',
2801
+ 'catcher',
2802
+ 'security guard',
2803
+ 'sketch',
2804
+ 'awning',
2805
+ 'steering wheel',
2806
+ 'mountain view',
2807
+ 'bus stop',
2808
+ 'pool',
2809
+ 'leg',
2810
+ 'spotlight',
2811
+ 'apron',
2812
+ 'mineral',
2813
+ 'inlet',
2814
+ 'sleeve',
2815
+ 'torch',
2816
+ 'emotion',
2817
+ 'march',
2818
+ 'police officer',
2819
+ 'performance',
2820
+ 'lamp post',
2821
+ 'fishing boat',
2822
+ 'summer',
2823
+ 'presentation',
2824
+ 'saucer',
2825
+ 'suitcase',
2826
+ 'supermodel',
2827
+ 'goalkeeper',
2828
+ 'shrub',
2829
+ 'rock artist',
2830
+ 'document',
2831
+ 'beach house',
2832
+ 'man',
2833
+ 'blue artist',
2834
+ 'cigar',
2835
+ 'railroad track',
2836
+ 'gown',
2837
+ 'mosaic',
2838
+ 'bungalow',
2839
+ 'alphabet',
2840
+ 'baseball field',
2841
+ 'shed',
2842
+ 'pedestrian',
2843
+ 'rail',
2844
+ 'soap',
2845
+ 'kitchen counter',
2846
+ 'dessert',
2847
+ 'dunk',
2848
+ 'blossom',
2849
+ 'conversation',
2850
+ 'fruit market',
2851
+ 'glass jar',
2852
+ 'military',
2853
+ 'beer bottle',
2854
+ 'photographer',
2855
+ 'tennis racket',
2856
+ 'competition',
2857
+ 'escalator',
2858
+ 'bell tower',
2859
+ 'stilt',
2860
+ 'ballerina',
2861
+ 'television',
2862
+ 'feather',
2863
+ 'fence post',
2864
+ 'rear',
2865
+ 'dahlia',
2866
+ 'red carpet',
2867
+ 'tub',
2868
+ 'hole',
2869
+ 'fortress',
2870
+ 'pack',
2871
+ 'telephone',
2872
+ 'cardboard',
2873
+ 'city park',
2874
+ 'platform',
2875
+ 'college student',
2876
+ 'arch bridge',
2877
+ 'wind',
2878
+ 'blender',
2879
+ 'bloom',
2880
+ 'ice rink',
2881
+ 'birthday',
2882
+ 'raven',
2883
+ 'fairy',
2884
+ 'embankment',
2885
+ 'hall',
2886
+ 'flower shop',
2887
+ 'suburb',
2888
+ 'barrel',
2889
+ 'biker',
2890
+ 'steam',
2891
+ 'dragonfly',
2892
+ 'formation',
2893
+ 'electricity',
2894
+ 'business people',
2895
+ 'symmetry',
2896
+ 'walkway',
2897
+ 'fisherman',
2898
+ 'gas mask',
2899
+ 'loch',
2900
+ 'youth',
2901
+ 'hanger',
2902
+ 'dot',
2903
+ 'fish',
2904
+ 'street market',
2905
+ 'animation film',
2906
+ 'crime fiction film',
2907
+ 'boar',
2908
+ 'emblem',
2909
+ 'halloween costume',
2910
+ 'kangaroo',
2911
+ 'couple',
2912
+ 'spoon',
2913
+ 'squirrel',
2914
+ 'neon sign',
2915
+ 'sky',
2916
+ 'office desk',
2917
+ 'beauty salon',
2918
+ 'breakwater',
2919
+ 'fashion look',
2920
+ 'toaster',
2921
+ 'author',
2922
+ 'news conference',
2923
+ 'outdoor',
2924
+ 'canoe',
2925
+ 'dragon',
2926
+ 'tool',
2927
+ 'shopping centre',
2928
+ 'ladybug',
2929
+ 'swimming pool',
2930
+ 'landscaping',
2931
+ 'ski pole',
2932
+ 'red',
2933
+ 'truck',
2934
+ 'fly',
2935
+ 'temple',
2936
+ 'level',
2937
+ 'sunday',
2938
+ 'railroad bridge',
2939
+ 'car mirror',
2940
+ 'lawn mower',
2941
+ 'flute',
2942
+ 'aircraft carrier',
2943
+ 'fashion menswear london week',
2944
+ 'sunshine',
2945
+ 'tile floor',
2946
+ 'skull',
2947
+ 'fossil',
2948
+ 'flower arrangement',
2949
+ 'diaper',
2950
+ 'sea turtle',
2951
+ 'cherry blossom',
2952
+ 'fireman',
2953
+ 'shack',
2954
+ 'lens',
2955
+ 'waiter',
2956
+ 'animal',
2957
+ 'basement',
2958
+ 'snow',
2959
+ 'autumn park',
2960
+ 'glass box',
2961
+ 'kick',
2962
+ 'head',
2963
+ 'anniversary',
2964
+ 'vine',
2965
+ 'back',
2966
+ 'paper lantern',
2967
+ 'fish tank',
2968
+ 'cellphone',
2969
+ 'silk',
2970
+ 'coral',
2971
+ 'notebook',
2972
+ 'photo',
2973
+ 'gazebo',
2974
+ 'ketchup',
2975
+ 'driver',
2976
+ 'farmer',
2977
+ 'bonfire',
2978
+ 'chestnut',
2979
+ 'photoshoot',
2980
+ 'football field',
2981
+ 'olive tree',
2982
+ 'pheasant',
2983
+ 'sandal',
2984
+ 'toilet',
2985
+ 'fireplace',
2986
+ 'music',
2987
+ 'deity',
2988
+ 'fish market',
2989
+ 'fig',
2990
+ 'bell',
2991
+ 'neck',
2992
+ 'grave',
2993
+ 'villa',
2994
+ 'cyclist',
2995
+ 'crate',
2996
+ 'grey',
2997
+ 'asphalt road',
2998
+ 'soccer',
2999
+ 'hostel',
3000
+ 'municipality',
3001
+ 'courthouse',
3002
+ 'roof',
3003
+ 'end table',
3004
+ 'pot',
3005
+ 'sedan',
3006
+ 'structure',
3007
+ 'folk artist',
3008
+ 'sport',
3009
+ 'sport team',
3010
+ 'protest',
3011
+ 'syringe',
3012
+ 'fashion designer',
3013
+ 'jersey',
3014
+ 'heart shape',
3015
+ 'kayak',
3016
+ 'stare',
3017
+ 'sit with',
3018
+ 'direct',
3019
+ 'read',
3020
+ 'photograph',
3021
+ 'spin',
3022
+ 'teach',
3023
+ 'laugh',
3024
+ 'carve',
3025
+ 'grow on',
3026
+ 'warm',
3027
+ 'watch',
3028
+ 'stretch',
3029
+ 'smell',
3030
+ 'decorate',
3031
+ 'shine',
3032
+ 'light',
3033
+ 'dance',
3034
+ 'send',
3035
+ 'park',
3036
+ 'chase',
3037
+ 'collect',
3038
+ 'lead',
3039
+ 'kiss',
3040
+ 'lead to',
3041
+ 'lick',
3042
+ 'smile',
3043
+ 'cheer',
3044
+ 'sit',
3045
+ 'point',
3046
+ 'block',
3047
+ 'rock',
3048
+ 'drop',
3049
+ 'cut',
3050
+ 'ski',
3051
+ 'wrap',
3052
+ 'lose',
3053
+ 'serve',
3054
+ 'provide',
3055
+ 'sleep',
3056
+ 'dress',
3057
+ 'embrace',
3058
+ 'burn',
3059
+ 'pack',
3060
+ 'stir',
3061
+ 'create',
3062
+ 'touch',
3063
+ 'wash',
3064
+ 'stick',
3065
+ 'reveal',
3066
+ 'shop',
3067
+ 'train',
3068
+ 'paint',
3069
+ 'groom',
3070
+ 'hunt',
3071
+ 'bloom',
3072
+ 'play',
3073
+ 'pay',
3074
+ 'brush',
3075
+ 'shoot',
3076
+ 'hold',
3077
+ 'picture',
3078
+ 'carry',
3079
+ 'sip',
3080
+ 'contain',
3081
+ 'turn',
3082
+ 'pour',
3083
+ 'pitch',
3084
+ 'give',
3085
+ 'add',
3086
+ 'blow',
3087
+ 'look in',
3088
+ 'show',
3089
+ 'walk',
3090
+ 'illuminate',
3091
+ 'kneel',
3092
+ 'cover',
3093
+ 'drag',
3094
+ 'post',
3095
+ 'present',
3096
+ 'fit',
3097
+ 'operate',
3098
+ 'fish',
3099
+ 'race',
3100
+ 'write',
3101
+ 'deliver',
3102
+ 'peel',
3103
+ 'push',
3104
+ 'run',
3105
+ 'sit around',
3106
+ 'buy',
3107
+ 'jump',
3108
+ 'walk on',
3109
+ 'attend',
3110
+ 'clean',
3111
+ 'sell',
3112
+ 'ride on',
3113
+ 'mount',
3114
+ 'host',
3115
+ 'dry',
3116
+ 'plant',
3117
+ 'sing',
3118
+ 'row',
3119
+ 'shake',
3120
+ 'perch',
3121
+ 'ride',
3122
+ 'fight',
3123
+ 'skateboard',
3124
+ 'live',
3125
+ 'call',
3126
+ 'surround',
3127
+ 'practice',
3128
+ 'play on',
3129
+ 'work on',
3130
+ 'step',
3131
+ 'relax',
3132
+ 'hit',
3133
+ 'fall in',
3134
+ 'flow',
3135
+ 'greet',
3136
+ 'launch',
3137
+ 'wear',
3138
+ 'hang on',
3139
+ 'drive',
3140
+ 'sit in',
3141
+ 'break',
3142
+ 'learn',
3143
+ 'fly',
3144
+ 'connect',
3145
+ 'display',
3146
+ 'locate',
3147
+ 'compete',
3148
+ 'go for',
3149
+ 'sail',
3150
+ 'lift',
3151
+ 'toast',
3152
+ 'help',
3153
+ 'run on',
3154
+ 'reflect',
3155
+ 'pose',
3156
+ 'scratch',
3157
+ 'frame',
3158
+ 'dribble',
3159
+ 'herd',
3160
+ 'enter',
3161
+ 'exit',
3162
+ 'place',
3163
+ 'inspect',
3164
+ 'build',
3165
+ 'pick',
3166
+ 'fill',
3167
+ 'grind',
3168
+ 'skate',
3169
+ 'offer',
3170
+ 'float',
3171
+ 'sit by',
3172
+ 'stand',
3173
+ 'release',
3174
+ 'rest',
3175
+ 'singe',
3176
+ 'climb',
3177
+ 'tie',
3178
+ 'mark',
3179
+ 'lay',
3180
+ 'stand around',
3181
+ 'capture',
3182
+ 'set',
3183
+ 'land',
3184
+ 'swinge',
3185
+ 'run in',
3186
+ 'kick',
3187
+ 'lean',
3188
+ 'head',
3189
+ 'sign',
3190
+ 'approach',
3191
+ 'swim',
3192
+ 'close',
3193
+ 'crash',
3194
+ 'control',
3195
+ 'fall',
3196
+ 'remove',
3197
+ 'repair',
3198
+ 'open',
3199
+ 'appear',
3200
+ 'travel',
3201
+ 'load',
3202
+ 'miss',
3203
+ 'check',
3204
+ 'surf',
3205
+ 'moor',
3206
+ 'smoke',
3207
+ 'drink',
3208
+ 'board',
3209
+ 'seat',
3210
+ 'feed',
3211
+ 'rise',
3212
+ 'sit on',
3213
+ 'swing',
3214
+ 'grow',
3215
+ 'strike',
3216
+ 'date',
3217
+ 'slide',
3218
+ 'share',
3219
+ 'graze',
3220
+ 'jump in',
3221
+ 'lie',
3222
+ 'extrude',
3223
+ 'roll',
3224
+ 'move',
3225
+ 'gather',
3226
+ 'eat',
3227
+ 'pull',
3228
+ 'run through',
3229
+ 'squeeze',
3230
+ 'lay on',
3231
+ 'draw',
3232
+ 'play with',
3233
+ 'wave',
3234
+ 'assemble',
3235
+ 'perform',
3236
+ 'march',
3237
+ 'score',
3238
+ 'attach',
3239
+ 'adjust',
3240
+ 'hang',
3241
+ 'hug',
3242
+ 'sleep on',
3243
+ 'throw',
3244
+ 'live in',
3245
+ 'talk',
3246
+ 'pet',
3247
+ 'work',
3248
+ 'run with',
3249
+ 'see',
3250
+ 'flip',
3251
+ 'catch',
3252
+ 'cook',
3253
+ 'receive',
3254
+ 'celebrate',
3255
+ 'look',
3256
+ 'classic',
3257
+ 'bridal',
3258
+ 'indoor',
3259
+ 'industrial',
3260
+ 'teenage',
3261
+ 'mini',
3262
+ 'grassy',
3263
+ 'aged',
3264
+ 'long',
3265
+ 'warm',
3266
+ 'light',
3267
+ 'handsome',
3268
+ 'happy',
3269
+ 'three',
3270
+ 'pregnant',
3271
+ 'circular',
3272
+ 'urban',
3273
+ 'silver',
3274
+ 'ceramic',
3275
+ '3d',
3276
+ 'green',
3277
+ 'blonde',
3278
+ 'golden',
3279
+ 'dark',
3280
+ 'tropical',
3281
+ 'ripe',
3282
+ 'deep',
3283
+ 'fat',
3284
+ 'musical',
3285
+ 'giant',
3286
+ 'medical',
3287
+ 'medieval',
3288
+ 'bare',
3289
+ 'stunning',
3290
+ 'bold',
3291
+ 'geographical',
3292
+ 'huge',
3293
+ 'plastic',
3294
+ 'foggy',
3295
+ 'stormy',
3296
+ 'gothic',
3297
+ 'biological',
3298
+ 'empty',
3299
+ 'clear',
3300
+ 'antique',
3301
+ 'pink',
3302
+ 'steep',
3303
+ 'brown',
3304
+ 'striped',
3305
+ 'aerial',
3306
+ 'rainy',
3307
+ 'cool',
3308
+ 'flying',
3309
+ 'commercial',
3310
+ 'purple',
3311
+ 'trendy',
3312
+ 'blank',
3313
+ 'haired',
3314
+ 'dead',
3315
+ 'wooden',
3316
+ 'flat',
3317
+ 'high',
3318
+ 'beige',
3319
+ 'panoramic',
3320
+ 'angry',
3321
+ 'dozen',
3322
+ 'rural',
3323
+ 'solar',
3324
+ 'big',
3325
+ 'small',
3326
+ 'stained',
3327
+ 'thick',
3328
+ 'many',
3329
+ 'fresh',
3330
+ 'clean',
3331
+ 'strong',
3332
+ 'abstract',
3333
+ 'crowded',
3334
+ 'retro',
3335
+ 'dry',
3336
+ 'gorgeous',
3337
+ 'martial',
3338
+ 'modern',
3339
+ 'blue',
3340
+ 'cloudy',
3341
+ 'low',
3342
+ 'four',
3343
+ 'outdoor',
3344
+ 'single',
3345
+ 'much',
3346
+ 'beautiful',
3347
+ 'snowy',
3348
+ 'pretty',
3349
+ 'new',
3350
+ 'short',
3351
+ 'sunny',
3352
+ 'closed',
3353
+ 'rocky',
3354
+ 'red',
3355
+ 'two',
3356
+ 'double',
3357
+ 'male',
3358
+ 'gray',
3359
+ 'five',
3360
+ 'colorful',
3361
+ 'automotive',
3362
+ 'various',
3363
+ 'one',
3364
+ 'old',
3365
+ 'rusty',
3366
+ 'tall',
3367
+ 'wild',
3368
+ 'narrow',
3369
+ 'natural',
3370
+ 'several',
3371
+ 'frozen',
3372
+ 'textured',
3373
+ 'lush',
3374
+ 'young',
3375
+ 'hot',
3376
+ 'mixed',
3377
+ 'white',
3378
+ 'float',
3379
+ 'quiet',
3380
+ 'round',
3381
+ 'bright',
3382
+ 'religious',
3383
+ 'female',
3384
+ 'historical',
3385
+ 'shiny',
3386
+ 'traditional',
3387
+ 'tourist',
3388
+ 'yellow',
3389
+ 'bald',
3390
+ 'coastal',
3391
+ 'lovely',
3392
+ 'little',
3393
+ 'broken',
3394
+ 'romantic',
3395
+ 'wide',
3396
+ 'royal',
3397
+ 'rich',
3398
+ 'open',
3399
+ 'cute',
3400
+ 'ancient',
3401
+ 'cold',
3402
+ 'political',
3403
+ 'elderly',
3404
+ 'gold',
3405
+ 'full',
3406
+ 'rustic',
3407
+ 'metallic',
3408
+ 'floral',
3409
+ 'sad',
3410
+ 'wet',
3411
+ 'fancy',
3412
+ 'senior',
3413
+ 'tiny',
3414
+ 'stylish',
3415
+ 'large',
3416
+ 'frosty',
3417
+ 'orange',
3418
+ 'transparent',
3419
+ 'electronic',
3420
+ 'shallow',
3421
+ 'scared',
3422
+ 'armed',
3423
+ 'dirty',
3424
+ 'historic',
3425
+ 'black',
3426
+ 'few',
3427
+ 'windy',
3428
+ 'some',
3429
+ 'square',
3430
+ 'ornamental',
3431
+ 'sandy',
3432
+ 'thin']
3433
+
3434
+
3435
+ tra_array = np.array(tra_array)
3436
+
3437
+
gradio_demo.ipynb ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "35d8939e-909d-45d8-bcf9-0ff1dccacfdf",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n",
15
+ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['bert.encoder.layer.2.attention.self.key.bias', 'cls.seq_relationship.weight', 'bert.encoder.layer.5.intermediate.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'bert.encoder.layer.7.output.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.2.intermediate.dense.bias', 'bert.encoder.layer.6.attention.self.value.weight', 'bert.encoder.layer.5.attention.self.query.bias', 'bert.encoder.layer.8.intermediate.dense.bias', 'bert.encoder.layer.4.output.dense.bias', 'bert.encoder.layer.8.attention.output.dense.weight', 'bert.encoder.layer.8.attention.self.query.bias', 'bert.encoder.layer.4.attention.output.dense.weight', 'bert.encoder.layer.7.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.bias', 'bert.encoder.layer.8.output.LayerNorm.bias', 'bert.encoder.layer.2.output.LayerNorm.bias', 'bert.encoder.layer.3.attention.self.value.weight', 'bert.encoder.layer.2.intermediate.dense.weight', 'bert.encoder.layer.5.attention.output.dense.bias', 'bert.encoder.layer.11.intermediate.dense.weight', 'cls.predictions.transform.dense.weight', 'bert.encoder.layer.4.attention.self.key.bias', 'bert.encoder.layer.2.attention.output.LayerNorm.bias', 'bert.encoder.layer.7.output.LayerNorm.bias', 'bert.encoder.layer.5.intermediate.dense.bias', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.3.attention.self.query.bias', 'bert.encoder.layer.11.attention.self.query.bias', 'bert.encoder.layer.7.attention.self.value.bias', 'bert.encoder.layer.6.output.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.attention.output.LayerNorm.weight', 'cls.predictions.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.8.attention.self.value.weight', 'cls.predictions.transform.dense.bias', 'bert.encoder.layer.11.attention.self.query.weight', 'bert.encoder.layer.8.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.value.weight', 'bert.encoder.layer.2.attention.self.key.weight', 'bert.encoder.layer.3.attention.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.output.LayerNorm.bias', 'bert.encoder.layer.2.output.dense.weight', 'bert.encoder.layer.3.attention.output.dense.bias', 'bert.encoder.layer.11.attention.output.dense.weight', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.7.attention.output.dense.bias', 'bert.encoder.layer.11.output.dense.bias', 'bert.pooler.dense.bias', 'bert.encoder.layer.11.attention.self.value.bias', 'bert.encoder.layer.6.attention.self.query.bias', 'bert.encoder.layer.6.output.dense.weight', 'bert.encoder.layer.9.output.LayerNorm.bias', 'bert.encoder.layer.4.output.LayerNorm.weight', 'bert.encoder.layer.9.output.LayerNorm.weight', 'bert.encoder.layer.9.intermediate.dense.bias', 'cls.predictions.decoder.weight', 'bert.encoder.layer.4.attention.output.dense.bias', 'bert.encoder.layer.4.attention.self.value.weight', 'bert.encoder.layer.7.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.bias', 'bert.encoder.layer.6.attention.output.dense.weight', 'bert.encoder.layer.7.attention.self.key.weight', 'bert.encoder.layer.6.attention.output.dense.bias', 'bert.encoder.layer.10.attention.self.value.bias', 'cls.seq_relationship.bias', 'bert.encoder.layer.3.attention.self.key.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.output.dense.bias', 'bert.encoder.layer.4.attention.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.4.attention.self.query.weight', 'bert.encoder.layer.4.intermediate.dense.weight', 'bert.encoder.layer.4.attention.self.query.bias', 'bert.encoder.layer.6.output.LayerNorm.weight', 'bert.encoder.layer.3.attention.output.dense.weight', 'bert.encoder.layer.3.intermediate.dense.weight', 'bert.encoder.layer.3.intermediate.dense.bias', 'bert.encoder.layer.4.attention.self.value.bias', 'bert.encoder.layer.9.output.dense.weight', 'bert.pooler.dense.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.self.query.weight', 'bert.encoder.layer.5.attention.output.dense.weight', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.11.output.LayerNorm.weight', 'bert.encoder.layer.9.attention.self.query.bias', 'bert.encoder.layer.6.attention.self.value.bias', 'bert.encoder.layer.8.attention.self.value.bias', 'bert.encoder.layer.7.intermediate.dense.bias', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.5.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.output.dense.bias', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.7.output.dense.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.weight', 'bert.encoder.layer.6.attention.self.query.weight', 'bert.encoder.layer.6.attention.self.key.bias', 'bert.encoder.layer.3.attention.self.query.weight', 'bert.encoder.layer.11.output.dense.weight', 'bert.encoder.layer.9.attention.self.key.bias', 'bert.encoder.layer.2.attention.output.dense.bias', 'bert.encoder.layer.9.attention.output.dense.weight', 'bert.encoder.layer.2.attention.output.LayerNorm.weight', 'bert.encoder.layer.5.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.self.key.weight', 'bert.encoder.layer.4.output.dense.weight', 'bert.encoder.layer.3.attention.self.key.bias', 'bert.encoder.layer.5.output.dense.bias', 'bert.encoder.layer.3.attention.self.value.bias', 'bert.encoder.layer.9.attention.self.key.weight', 'bert.encoder.layer.3.attention.output.LayerNorm.bias', 'bert.encoder.layer.4.intermediate.dense.bias', 'bert.encoder.layer.3.output.LayerNorm.weight', 'bert.encoder.layer.8.attention.self.query.weight', 'bert.encoder.layer.2.output.LayerNorm.weight', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.4.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.11.attention.output.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'bert.encoder.layer.2.attention.output.dense.weight', 'bert.encoder.layer.6.intermediate.dense.bias', 'bert.encoder.layer.7.attention.output.LayerNorm.bias', 'bert.encoder.layer.2.output.dense.bias', 'bert.encoder.layer.5.attention.self.key.bias', 'bert.encoder.layer.9.output.dense.bias', 'bert.encoder.layer.2.attention.self.query.weight', 'bert.encoder.layer.5.output.dense.weight', 'bert.encoder.layer.5.attention.self.value.weight', 'bert.encoder.layer.3.output.LayerNorm.bias', 'bert.encoder.layer.11.output.LayerNorm.bias', 'bert.encoder.layer.7.attention.self.query.bias', 'bert.encoder.layer.6.output.LayerNorm.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.bias', 'bert.encoder.layer.3.output.dense.weight', 'bert.encoder.layer.7.attention.self.value.weight', 'bert.encoder.layer.8.output.dense.bias', 'bert.encoder.layer.5.attention.self.query.weight', 'bert.encoder.layer.5.output.LayerNorm.bias', 'bert.encoder.layer.2.attention.self.value.weight', 'bert.encoder.layer.5.attention.self.key.weight', 'bert.encoder.layer.6.attention.self.key.weight', 'bert.encoder.layer.11.intermediate.dense.bias', 'bert.encoder.layer.6.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.3.output.dense.bias', 'bert.encoder.layer.6.attention.output.LayerNorm.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.9.attention.output.LayerNorm.weight', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.4.attention.self.key.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.encoder.layer.7.attention.self.query.weight', 'bert.encoder.layer.8.output.dense.weight', 'bert.encoder.layer.5.attention.self.value.bias', 'bert.encoder.layer.2.attention.self.value.bias', 'bert.encoder.layer.9.attention.self.value.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.2.attention.self.query.bias', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.8.attention.self.key.bias', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.8.attention.self.key.weight', 'bert.encoder.layer.9.attention.self.value.weight']\n",
16
+ "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
17
+ "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
18
+ "Some weights of BertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.self.query.weight', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.encoder.layer.1.crossattention.self.query.bias', 'bert.encoder.layer.1.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight']\n",
19
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
20
+ ]
21
+ },
22
+ {
23
+ "name": "stdout",
24
+ "output_type": "stream",
25
+ "text": [
26
+ "/encoder/layer/0/crossattention/self/query is tied\n",
27
+ "/encoder/layer/0/crossattention/self/key is tied\n",
28
+ "/encoder/layer/0/crossattention/self/value is tied\n",
29
+ "/encoder/layer/0/crossattention/output/dense is tied\n",
30
+ "/encoder/layer/0/crossattention/output/LayerNorm is tied\n",
31
+ "/encoder/layer/0/intermediate/dense is tied\n",
32
+ "/encoder/layer/0/output/dense is tied\n",
33
+ "/encoder/layer/0/output/LayerNorm is tied\n",
34
+ "/encoder/layer/1/crossattention/self/query is tied\n",
35
+ "/encoder/layer/1/crossattention/self/key is tied\n",
36
+ "/encoder/layer/1/crossattention/self/value is tied\n",
37
+ "/encoder/layer/1/crossattention/output/dense is tied\n",
38
+ "/encoder/layer/1/crossattention/output/LayerNorm is tied\n",
39
+ "/encoder/layer/1/intermediate/dense is tied\n",
40
+ "/encoder/layer/1/output/dense is tied\n",
41
+ "/encoder/layer/1/output/LayerNorm is tied\n",
42
+ "--------------\n",
43
+ "/home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_14m/blip_tagtotext_encoderdiv_tar_random_swin/caption_coco_finetune_tagparse_tagfinetune_threshold075_bceloss_tagsingle_5e6_epoch19_negative_1_05_pos_1_10/checkpoint_05.pth\n",
44
+ "--------------\n",
45
+ "load checkpoint from /home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_14m/blip_tagtotext_encoderdiv_tar_random_swin/caption_coco_finetune_tagparse_tagfinetune_threshold075_bceloss_tagsingle_5e6_epoch19_negative_1_05_pos_1_10/checkpoint_05.pth\n",
46
+ "vit: swin_b\n",
47
+ "msg_v2 _IncompatibleKeys(missing_keys=['visual_encoder.layers.0.blocks.0.attn.relative_position_index', 'visual_encoder.layers.0.blocks.1.attn_mask', 'visual_encoder.layers.0.blocks.1.attn.relative_position_index', 'visual_encoder.layers.1.blocks.0.attn.relative_position_index', 'visual_encoder.layers.1.blocks.1.attn_mask', 'visual_encoder.layers.1.blocks.1.attn.relative_position_index', 'visual_encoder.layers.2.blocks.0.attn.relative_position_index', 'visual_encoder.layers.2.blocks.1.attn_mask', 'visual_encoder.layers.2.blocks.1.attn.relative_position_index', 'visual_encoder.layers.2.blocks.2.attn.relative_position_index', 'visual_encoder.layers.2.blocks.3.attn_mask', 'visual_encoder.layers.2.blocks.3.attn.relative_position_index', 'visual_encoder.layers.2.blocks.4.attn.relative_position_index', 'visual_encoder.layers.2.blocks.5.attn_mask', 'visual_encoder.layers.2.blocks.5.attn.relative_position_index', 'visual_encoder.layers.2.blocks.6.attn.relative_position_index', 'visual_encoder.layers.2.blocks.7.attn_mask', 'visual_encoder.layers.2.blocks.7.attn.relative_position_index', 'visual_encoder.layers.2.blocks.8.attn.relative_position_index', 'visual_encoder.layers.2.blocks.9.attn_mask', 'visual_encoder.layers.2.blocks.9.attn.relative_position_index', 'visual_encoder.layers.2.blocks.10.attn.relative_position_index', 'visual_encoder.layers.2.blocks.11.attn_mask', 'visual_encoder.layers.2.blocks.11.attn.relative_position_index', 'visual_encoder.layers.2.blocks.12.attn.relative_position_index', 'visual_encoder.layers.2.blocks.13.attn_mask', 'visual_encoder.layers.2.blocks.13.attn.relative_position_index', 'visual_encoder.layers.2.blocks.14.attn.relative_position_index', 'visual_encoder.layers.2.blocks.15.attn_mask', 'visual_encoder.layers.2.blocks.15.attn.relative_position_index', 'visual_encoder.layers.2.blocks.16.attn.relative_position_index', 'visual_encoder.layers.2.blocks.17.attn_mask', 'visual_encoder.layers.2.blocks.17.attn.relative_position_index', 'visual_encoder.layers.3.blocks.0.attn.relative_position_index', 'visual_encoder.layers.3.blocks.1.attn.relative_position_index'], unexpected_keys=[])\n"
48
+ ]
49
+ }
50
+ ],
51
+ "source": [
52
+ "from PIL import Image\n",
53
+ "import requests\n",
54
+ "import torch\n",
55
+ "from torchvision import transforms\n",
56
+ "from torchvision.transforms.functional import InterpolationMode\n",
57
+ "import ruamel_yaml as yaml\n",
58
+ "from models.tag2text import tag2text_caption\n",
59
+ "\n",
60
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
61
+ "\n",
62
+ "\n",
63
+ "\n",
64
+ "import gradio as gr\n",
65
+ "\n",
66
+ "image_size = 384\n",
67
+ "\n",
68
+ "\n",
69
+ "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
70
+ " std=[0.229, 0.224, 0.225])\n",
71
+ "transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.ToTensor(),normalize])\n",
72
+ "\n",
73
+ "\n",
74
+ "\n",
75
+ "#######Swin Version\n",
76
+ "pretrained = '/home/notebook/code/personal/S9049611/BLIP/output/blip_tagtotext_14m/blip_tagtotext_encoderdiv_tar_random_swin/caption_coco_finetune_tagparse_tagfinetune_threshold075_bceloss_tagsingle_5e6_epoch19_negative_1_05_pos_1_10/checkpoint_05.pth'\n",
77
+ "\n",
78
+ "config_file = 'configs/tag2text_caption.yaml'\n",
79
+ "config = yaml.load(open(config_file, 'r'), Loader=yaml.Loader)\n",
80
+ "\n",
81
+ "\n",
82
+ "model = tag2text_caption(pretrained=pretrained, image_size=image_size, vit=config['vit'], \n",
83
+ " vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],\n",
84
+ " prompt=config['prompt'],config=config,threshold = 0.75 )\n",
85
+ "\n",
86
+ "model.eval()\n",
87
+ "model = model.to(device)\n",
88
+ "\n",
89
+ "\n"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 6,
95
+ "id": "9772dc6f-680d-45a7-b39c-23770eb5258e",
96
+ "metadata": {},
97
+ "outputs": [
98
+ {
99
+ "name": "stdout",
100
+ "output_type": "stream",
101
+ "text": [
102
+ "Running on local URL: http://127.0.0.1:7864\n",
103
+ "Running on public URL: https://a10a3bf9-64b6-49d4.gradio.live\n",
104
+ "\n",
105
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces\n"
106
+ ]
107
+ },
108
+ {
109
+ "data": {
110
+ "text/html": [
111
+ "<div><iframe src=\"https://a10a3bf9-64b6-49d4.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
112
+ ],
113
+ "text/plain": [
114
+ "<IPython.core.display.HTML object>"
115
+ ]
116
+ },
117
+ "metadata": {},
118
+ "output_type": "display_data"
119
+ },
120
+ {
121
+ "data": {
122
+ "text/plain": []
123
+ },
124
+ "execution_count": 6,
125
+ "metadata": {},
126
+ "output_type": "execute_result"
127
+ }
128
+ ],
129
+ "source": [
130
+ "\n",
131
+ "def inference(raw_image, input_tag):\n",
132
+ " raw_image = raw_image.resize((image_size, image_size))\n",
133
+ " # print(type(raw_image))\n",
134
+ " image = transform(raw_image).unsqueeze(0).to(device) \n",
135
+ " model.threshold = 0.69\n",
136
+ " if input_tag == '' or input_tag == 'none' or input_tag == 'None':\n",
137
+ " input_tag_list = None\n",
138
+ " else:\n",
139
+ " input_tag_list = []\n",
140
+ " input_tag_list.append(input_tag.replace(',',' | '))\n",
141
+ " # print(input_tag_list)\n",
142
+ " with torch.no_grad():\n",
143
+ "\n",
144
+ "\n",
145
+ " caption, tag_predict = model.generate(image,tag_input = input_tag_list, return_tag_predict = True)\n",
146
+ " if input_tag_list == None:\n",
147
+ " tag_1 = tag_predict\n",
148
+ " tag_2 = ['none']\n",
149
+ " else:\n",
150
+ " _, tag_1 = model.generate(image,tag_input = None, return_tag_predict = True)\n",
151
+ " tag_2 = tag_predict\n",
152
+ "\n",
153
+ "\n",
154
+ " return tag_1[0],tag_2[0],caption[0]\n",
155
+ "\n",
156
+ " # return 'caption: '+caption[0], tag_predict[0]\n",
157
+ "\n",
158
+ "\n",
159
+ " \n",
160
+ "# inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=['Image Captioning'], type=\"value\", default=\"Image Captioning\", label=\"Task\"),gr.inputs.Textbox(lines=2, label=\"User Identified Tags (Optional, Enter with commas)\"),gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type=\"value\", default=\"Beam search\", label=\"Caption Decoding Strategy\")]\n",
161
+ "inputs = [gr.inputs.Image(type='pil'),gr.inputs.Textbox(lines=2, label=\"User Specified Tags (Optional, Enter with commas)\")]\n",
162
+ "\n",
163
+ "# outputs = gr.outputs.Textbox(label=\"Output\")\n",
164
+ "# outputs = [gr.outputs.Textbox(label=\"Image Caption\"),gr.outputs.Textbox(label=\"Identified Tags\")]\n",
165
+ "outputs = [gr.outputs.Textbox(label=\"Model Identified Tags\"),gr.outputs.Textbox(label=\"User Specified Tags\"), gr.outputs.Textbox(label=\"Image Caption\") ]\n",
166
+ "\n",
167
+ "title = \"Tag2Text\"\n",
168
+ "description = \"Welcome to Tag2Text demo! (Supported by Fudan University, OPPO Research Institute, International Digital Economy Academy) <br/> Upload your image to get the tags and caption of the image. Optional: You can also input specified tags to get the corresponding caption.\"\n",
169
+ "\n",
170
+ "article = \"<p style='text-align: center'><a href='' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='' target='_blank'>Github Repo</a></p>\"\n",
171
+ "\n",
172
+ "demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['images/COCO_val2014_000000483108.jpg',\"none\"],\n",
173
+ " ['images/COCO_val2014_000000483108.jpg',\"electric cable\"],\n",
174
+ " ['images/COCO_val2014_000000483108.jpg',\"track, train\"] , \n",
175
+ " ])\n",
176
+ "\n",
177
+ "\n",
178
+ "demo.launch(share=True)\n",
179
+ "# demo.launch()"
180
+ ]
181
+ },
182
+ {
183
+ "cell_type": "code",
184
+ "execution_count": null,
185
+ "id": "0da1f11b-e737-47a9-9b07-4e00c0835f63",
186
+ "metadata": {},
187
+ "outputs": [],
188
+ "source": [
189
+ "\n",
190
+ "def inference(raw_image, input_tag):\n",
191
+ " raw_image = raw_image.resize((image_size, image_size))\n",
192
+ " # print(type(raw_image))\n",
193
+ " image = transform(raw_image).unsqueeze(0).to(device) \n",
194
+ " model.threshold = 0.69\n",
195
+ " if input_tag == '' or input_tag == 'none' or input_tag == 'None':\n",
196
+ " input_tag_list = None\n",
197
+ " else:\n",
198
+ " input_tag_list = []\n",
199
+ " input_tag_list.append(input_tag.replace(',',' | '))\n",
200
+ " # print(input_tag_list)\n",
201
+ " with torch.no_grad():\n",
202
+ "\n",
203
+ "\n",
204
+ " caption, tag_predict = model.generate(image,tag_input = input_tag_list, return_tag_predict = True)\n",
205
+ " if input_tag_list == None:\n",
206
+ " tag_1 = tag_predict\n",
207
+ " tag_2 = ['none']\n",
208
+ " else:\n",
209
+ " _, tag_1 = model.generate(image,tag_input = None, return_tag_predict = True)\n",
210
+ " tag_2 = tag_predict\n",
211
+ "\n",
212
+ "\n",
213
+ " return tag_1[0],tag_2[0],caption[0]\n",
214
+ "\n",
215
+ " # return 'caption: '+caption[0], tag_predict[0]\n",
216
+ "\n",
217
+ "\n",
218
+ " \n",
219
+ "# inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=['Image Captioning'], type=\"value\", default=\"Image Captioning\", label=\"Task\"),gr.inputs.Textbox(lines=2, label=\"User Identified Tags (Optional, Enter with commas)\"),gr.inputs.Radio(choices=['Beam search','Nucleus sampling'], type=\"value\", default=\"Beam search\", label=\"Caption Decoding Strategy\")]\n",
220
+ "inputs = [gr.inputs.Image(type='pil'),gr.inputs.Textbox(lines=2, label=\"User Specified Tags (Optional, Enter with commas)\")]\n",
221
+ "\n",
222
+ "# outputs = gr.outputs.Textbox(label=\"Output\")\n",
223
+ "# outputs = [gr.outputs.Textbox(label=\"Image Caption\"),gr.outputs.Textbox(label=\"Identified Tags\")]\n",
224
+ "outputs = [gr.outputs.Textbox(label=\"Model Identified Tags\"),gr.outputs.Textbox(label=\"User Specified Tags\"), gr.outputs.Textbox(label=\"Image Caption\") ]\n",
225
+ "\n",
226
+ "title = \"Tag2Text\"\n",
227
+ "description = \"Welcome to Tag2Text demo! (Supported by Fudan University, OPPO Research Institute, International Digital Economy Academy) <br/> Upload your image to get the tags and caption of the image. Optional: You can also input specified tags to get the corresponding caption.\"\n",
228
+ "\n",
229
+ "article = \"<p style='text-align: center'><a href='' target='_blank'>Tag2Text: Guiding Language-Image Model via Image Tagging</a> | <a href='' target='_blank'>Github Repo</a></p>\"\n",
230
+ "\n",
231
+ "demo = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['images/COCO_val2014_000000551338.jpg',\"none\"], \n",
232
+ " ['images/COCO_val2014_000000551338.jpg',\"fence, sky\"],\n",
233
+ " # ['images/COCO_val2014_000000551338.jpg',\"grass\"],\n",
234
+ " ['images/COCO_val2014_000000483108.jpg',\"none\"],\n",
235
+ " ['images/COCO_val2014_000000483108.jpg',\"electric cable\"],\n",
236
+ " # ['images/COCO_val2014_000000483108.jpg',\"sky, train\"],\n",
237
+ " ['images/COCO_val2014_000000483108.jpg',\"track, train\"] , \n",
238
+ " ['images/COCO_val2014_000000483108.jpg',\"grass\"] \n",
239
+ " ])\n",
240
+ "\n",
241
+ "\n",
242
+ "demo.launch(share=True)\n",
243
+ "# demo.launch()"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "id": "73a4bb88-4200-4853-b1ba-34f0d4b6dc34",
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": []
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "id": "3340a61f-c6bc-4ead-87ea-b26aa97b7a68",
258
+ "metadata": {},
259
+ "outputs": [],
260
+ "source": []
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": null,
265
+ "id": "d49e3de4-c3f7-4835-90eb-d0d013fc0ffb",
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": []
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": null,
273
+ "id": "205e0317-1701-4afd-8d67-bedb6959f350",
274
+ "metadata": {},
275
+ "outputs": [],
276
+ "source": []
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "id": "bf5301a5-80c5-4e44-835e-0160a97fef66",
282
+ "metadata": {},
283
+ "outputs": [],
284
+ "source": []
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": null,
289
+ "id": "f63d7a06-7625-4e1c-855d-177971217a0d",
290
+ "metadata": {},
291
+ "outputs": [],
292
+ "source": []
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": null,
297
+ "id": "c929e566-1a6e-4280-96eb-c434ef9a35d0",
298
+ "metadata": {},
299
+ "outputs": [],
300
+ "source": []
301
+ }
302
+ ],
303
+ "metadata": {
304
+ "kernelspec": {
305
+ "display_name": "Python 3 (ipykernel)",
306
+ "language": "python",
307
+ "name": "python3"
308
+ },
309
+ "language_info": {
310
+ "codemirror_mode": {
311
+ "name": "ipython",
312
+ "version": 3
313
+ },
314
+ "file_extension": ".py",
315
+ "mimetype": "text/x-python",
316
+ "name": "python",
317
+ "nbconvert_exporter": "python",
318
+ "pygments_lexer": "ipython3",
319
+ "version": "3.7.12"
320
+ }
321
+ },
322
+ "nbformat": 4,
323
+ "nbformat_minor": 5
324
+ }
images/COCO_val2014_000000483108.jpg ADDED
images/COCO_val2014_000000551338.jpg ADDED
models/__pycache__/med.cpython-37.pyc ADDED
Binary file (29.2 kB). View file
 
models/__pycache__/swin_transformer.cpython-37.pyc ADDED
Binary file (21.6 kB). View file
 
models/__pycache__/tag2text.cpython-37.pyc ADDED
Binary file (11.9 kB). View file
 
models/__pycache__/vit.cpython-37.pyc ADDED
Binary file (12.3 kB). View file
 
models/med.py ADDED
@@ -0,0 +1,1031 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ '''
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings_nopos(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
58
+ # self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
59
+
60
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
61
+ # any TensorFlow checkpoint file
62
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
63
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
64
+
65
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
66
+ # self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
67
+ # self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
68
+
69
+ self.config = config
70
+
71
+ def forward(
72
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
73
+ ):
74
+ if input_ids is not None:
75
+ input_shape = input_ids.size()
76
+ else:
77
+ input_shape = inputs_embeds.size()[:-1]
78
+
79
+ seq_length = input_shape[1]
80
+
81
+ # if position_ids is None:
82
+ # position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
83
+
84
+ if inputs_embeds is None:
85
+ inputs_embeds = self.word_embeddings(input_ids)
86
+
87
+ embeddings = inputs_embeds
88
+
89
+ # if self.position_embedding_type == "absolute":
90
+ # position_embeddings = self.position_embeddings(position_ids)
91
+ # # print('add position_embeddings!!!!')
92
+ # embeddings += position_embeddings
93
+ embeddings = self.LayerNorm(embeddings)
94
+ embeddings = self.dropout(embeddings)
95
+ return embeddings
96
+
97
+
98
+
99
+
100
+ class BertEmbeddings(nn.Module):
101
+ """Construct the embeddings from word and position embeddings."""
102
+
103
+ def __init__(self, config):
104
+ super().__init__()
105
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
106
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
107
+
108
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
109
+ # any TensorFlow checkpoint file
110
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
111
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
112
+
113
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
114
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
115
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
116
+
117
+ self.config = config
118
+
119
+ def forward(
120
+ self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
121
+ ):
122
+ if input_ids is not None:
123
+ input_shape = input_ids.size()
124
+ else:
125
+ input_shape = inputs_embeds.size()[:-1]
126
+
127
+ seq_length = input_shape[1]
128
+
129
+ if position_ids is None:
130
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
131
+
132
+ if inputs_embeds is None:
133
+ inputs_embeds = self.word_embeddings(input_ids)
134
+
135
+ embeddings = inputs_embeds
136
+
137
+ if self.position_embedding_type == "absolute":
138
+ position_embeddings = self.position_embeddings(position_ids)
139
+ # print('add position_embeddings!!!!')
140
+ embeddings += position_embeddings
141
+ embeddings = self.LayerNorm(embeddings)
142
+ embeddings = self.dropout(embeddings)
143
+ return embeddings
144
+
145
+
146
+ class BertSelfAttention(nn.Module):
147
+ def __init__(self, config, is_cross_attention):
148
+ super().__init__()
149
+ self.config = config
150
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
151
+ raise ValueError(
152
+ "The hidden size (%d) is not a multiple of the number of attention "
153
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
154
+ )
155
+
156
+ self.num_attention_heads = config.num_attention_heads
157
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
158
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
159
+
160
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
161
+ if is_cross_attention:
162
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
163
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
164
+ else:
165
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
166
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
167
+
168
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
169
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
170
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
171
+ self.max_position_embeddings = config.max_position_embeddings
172
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
173
+ self.save_attention = False
174
+
175
+ def save_attn_gradients(self, attn_gradients):
176
+ self.attn_gradients = attn_gradients
177
+
178
+ def get_attn_gradients(self):
179
+ return self.attn_gradients
180
+
181
+ def save_attention_map(self, attention_map):
182
+ self.attention_map = attention_map
183
+
184
+ def get_attention_map(self):
185
+ return self.attention_map
186
+
187
+ def transpose_for_scores(self, x):
188
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
189
+ x = x.view(*new_x_shape)
190
+ return x.permute(0, 2, 1, 3)
191
+
192
+ def forward(
193
+ self,
194
+ hidden_states,
195
+ attention_mask=None,
196
+ head_mask=None,
197
+ encoder_hidden_states=None,
198
+ encoder_attention_mask=None,
199
+ past_key_value=None,
200
+ output_attentions=False,
201
+ ):
202
+ mixed_query_layer = self.query(hidden_states)
203
+
204
+ # If this is instantiated as a cross-attention module, the keys
205
+ # and values come from an encoder; the attention mask needs to be
206
+ # such that the encoder's padding tokens are not attended to.
207
+ is_cross_attention = encoder_hidden_states is not None
208
+
209
+ if is_cross_attention:
210
+ # print(self.key.weight.shape)
211
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
212
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
213
+ attention_mask = encoder_attention_mask
214
+ elif past_key_value is not None:
215
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
216
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
217
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
218
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
219
+ else:
220
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
221
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
222
+
223
+ query_layer = self.transpose_for_scores(mixed_query_layer)
224
+
225
+ past_key_value = (key_layer, value_layer)
226
+
227
+ # Take the dot product between "query" and "key" to get the raw attention scores.
228
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
229
+
230
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
231
+ seq_length = hidden_states.size()[1]
232
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
233
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
234
+ distance = position_ids_l - position_ids_r
235
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
236
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
237
+
238
+ if self.position_embedding_type == "relative_key":
239
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
240
+ attention_scores = attention_scores + relative_position_scores
241
+ elif self.position_embedding_type == "relative_key_query":
242
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
243
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
244
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
245
+
246
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
247
+ if attention_mask is not None:
248
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
249
+ attention_scores = attention_scores + attention_mask
250
+
251
+ # Normalize the attention scores to probabilities.
252
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
253
+
254
+ if is_cross_attention and self.save_attention:
255
+ self.save_attention_map(attention_probs)
256
+ attention_probs.register_hook(self.save_attn_gradients)
257
+
258
+ # This is actually dropping out entire tokens to attend to, which might
259
+ # seem a bit unusual, but is taken from the original Transformer paper.
260
+ attention_probs_dropped = self.dropout(attention_probs)
261
+
262
+ # Mask heads if we want to
263
+ if head_mask is not None:
264
+ attention_probs_dropped = attention_probs_dropped * head_mask
265
+
266
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
267
+
268
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
269
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
270
+ context_layer = context_layer.view(*new_context_layer_shape)
271
+
272
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
273
+
274
+ outputs = outputs + (past_key_value,)
275
+ return outputs
276
+
277
+
278
+ class BertSelfOutput(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
282
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+
285
+ def forward(self, hidden_states, input_tensor):
286
+ hidden_states = self.dense(hidden_states)
287
+ hidden_states = self.dropout(hidden_states)
288
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
289
+ return hidden_states
290
+
291
+
292
+ class BertAttention(nn.Module):
293
+ def __init__(self, config, is_cross_attention=False):
294
+ super().__init__()
295
+ self.self = BertSelfAttention(config, is_cross_attention)
296
+ self.output = BertSelfOutput(config)
297
+ self.pruned_heads = set()
298
+
299
+ def prune_heads(self, heads):
300
+ if len(heads) == 0:
301
+ return
302
+ heads, index = find_pruneable_heads_and_indices(
303
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
304
+ )
305
+
306
+ # Prune linear layers
307
+ self.self.query = prune_linear_layer(self.self.query, index)
308
+ self.self.key = prune_linear_layer(self.self.key, index)
309
+ self.self.value = prune_linear_layer(self.self.value, index)
310
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
311
+
312
+ # Update hyper params and store pruned heads
313
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
314
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
315
+ self.pruned_heads = self.pruned_heads.union(heads)
316
+
317
+ def forward(
318
+ self,
319
+ hidden_states,
320
+ attention_mask=None,
321
+ head_mask=None,
322
+ encoder_hidden_states=None,
323
+ encoder_attention_mask=None,
324
+ past_key_value=None,
325
+ output_attentions=False,
326
+ ):
327
+ self_outputs = self.self(
328
+ hidden_states,
329
+ attention_mask,
330
+ head_mask,
331
+ encoder_hidden_states,
332
+ encoder_attention_mask,
333
+ past_key_value,
334
+ output_attentions,
335
+ )
336
+ attention_output = self.output(self_outputs[0], hidden_states)
337
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
338
+ return outputs
339
+
340
+
341
+ class BertIntermediate(nn.Module):
342
+ def __init__(self, config):
343
+ super().__init__()
344
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
345
+ if isinstance(config.hidden_act, str):
346
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
347
+ else:
348
+ self.intermediate_act_fn = config.hidden_act
349
+
350
+ def forward(self, hidden_states):
351
+ hidden_states = self.dense(hidden_states)
352
+ hidden_states = self.intermediate_act_fn(hidden_states)
353
+ return hidden_states
354
+
355
+
356
+ class BertOutput(nn.Module):
357
+ def __init__(self, config):
358
+ super().__init__()
359
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
360
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
361
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
362
+
363
+ def forward(self, hidden_states, input_tensor):
364
+ hidden_states = self.dense(hidden_states)
365
+ hidden_states = self.dropout(hidden_states)
366
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
367
+ return hidden_states
368
+
369
+
370
+ class BertLayer(nn.Module):
371
+ def __init__(self, config, layer_num):
372
+ super().__init__()
373
+ self.config = config
374
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
375
+ self.seq_len_dim = 1
376
+ self.attention = BertAttention(config)
377
+ self.layer_num = layer_num
378
+ if self.config.add_cross_attention:
379
+ self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)
380
+ self.intermediate = BertIntermediate(config)
381
+ self.output = BertOutput(config)
382
+
383
+ def forward(
384
+ self,
385
+ hidden_states,
386
+ attention_mask=None,
387
+ head_mask=None,
388
+ encoder_hidden_states=None,
389
+ encoder_attention_mask=None,
390
+ past_key_value=None,
391
+ output_attentions=False,
392
+ mode=None,
393
+ ):
394
+
395
+ if mode == 'mlr':
396
+
397
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
398
+
399
+ # print('attention_output.shape',attention_output.shape)
400
+ # print('encoder_hidden_states.shape',encoder_hidden_states.shape)
401
+ cross_attention_outputs = self.crossattention(
402
+ hidden_states,
403
+ attention_mask,
404
+ head_mask,
405
+ encoder_hidden_states,
406
+ encoder_attention_mask,
407
+ output_attentions=output_attentions,
408
+ )
409
+ attention_output = cross_attention_outputs[0]
410
+ outputs = cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
411
+
412
+ present_key_value = cross_attention_outputs[-1]
413
+
414
+ else:
415
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
416
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
417
+ self_attention_outputs = self.attention(
418
+ hidden_states,
419
+ attention_mask,
420
+ head_mask,
421
+ output_attentions=output_attentions,
422
+ past_key_value=self_attn_past_key_value,
423
+ )
424
+ attention_output = self_attention_outputs[0]
425
+
426
+ outputs = self_attention_outputs[1:-1]
427
+ present_key_value = self_attention_outputs[-1]
428
+
429
+ if mode=='multimodal':
430
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
431
+
432
+ cross_attention_outputs = self.crossattention(
433
+ attention_output,
434
+ attention_mask,
435
+ head_mask,
436
+ encoder_hidden_states,
437
+ encoder_attention_mask,
438
+ output_attentions=output_attentions,
439
+ )
440
+ attention_output = cross_attention_outputs[0]
441
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
442
+ layer_output = apply_chunking_to_forward(
443
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
444
+ )
445
+ outputs = (layer_output,) + outputs
446
+
447
+ outputs = outputs + (present_key_value,)
448
+
449
+ return outputs
450
+
451
+ def feed_forward_chunk(self, attention_output):
452
+ intermediate_output = self.intermediate(attention_output)
453
+ layer_output = self.output(intermediate_output, attention_output)
454
+ return layer_output
455
+
456
+
457
+ class BertEncoder(nn.Module):
458
+ def __init__(self, config):
459
+ super().__init__()
460
+ self.config = config
461
+ self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)])
462
+ self.gradient_checkpointing = False
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states,
467
+ attention_mask=None,
468
+ head_mask=None,
469
+ encoder_hidden_states=None,
470
+ encoder_attention_mask=None,
471
+ past_key_values=None,
472
+ use_cache=None,
473
+ output_attentions=False,
474
+ output_hidden_states=False,
475
+ return_dict=True,
476
+ mode='multimodal',
477
+ ):
478
+ all_hidden_states = () if output_hidden_states else None
479
+ all_self_attentions = () if output_attentions else None
480
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
481
+
482
+ next_decoder_cache = () if use_cache else None
483
+
484
+ for i in range(self.config.num_hidden_layers):
485
+ layer_module = self.layer[i]
486
+ if output_hidden_states:
487
+ all_hidden_states = all_hidden_states + (hidden_states,)
488
+
489
+ layer_head_mask = head_mask[i] if head_mask is not None else None
490
+ past_key_value = past_key_values[i] if past_key_values is not None else None
491
+
492
+ if self.gradient_checkpointing and self.training:
493
+
494
+ if use_cache:
495
+ logger.warn(
496
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
497
+ )
498
+ use_cache = False
499
+
500
+ def create_custom_forward(module):
501
+ def custom_forward(*inputs):
502
+ return module(*inputs, past_key_value, output_attentions)
503
+
504
+ return custom_forward
505
+
506
+ layer_outputs = torch.utils.checkpoint.checkpoint(
507
+ create_custom_forward(layer_module),
508
+ hidden_states,
509
+ attention_mask,
510
+ layer_head_mask,
511
+ encoder_hidden_states,
512
+ encoder_attention_mask,
513
+ mode=mode,
514
+ )
515
+ else:
516
+ layer_outputs = layer_module(
517
+ hidden_states,
518
+ attention_mask,
519
+ layer_head_mask,
520
+ encoder_hidden_states,
521
+ encoder_attention_mask,
522
+ past_key_value,
523
+ output_attentions,
524
+ mode=mode,
525
+ )
526
+
527
+ hidden_states = layer_outputs[0]
528
+ if use_cache:
529
+ next_decoder_cache += (layer_outputs[-1],)
530
+ if output_attentions:
531
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
532
+
533
+ if output_hidden_states:
534
+ all_hidden_states = all_hidden_states + (hidden_states,)
535
+
536
+ if not return_dict:
537
+ return tuple(
538
+ v
539
+ for v in [
540
+ hidden_states,
541
+ next_decoder_cache,
542
+ all_hidden_states,
543
+ all_self_attentions,
544
+ all_cross_attentions,
545
+ ]
546
+ if v is not None
547
+ )
548
+ return BaseModelOutputWithPastAndCrossAttentions(
549
+ last_hidden_state=hidden_states,
550
+ past_key_values=next_decoder_cache,
551
+ hidden_states=all_hidden_states,
552
+ attentions=all_self_attentions,
553
+ cross_attentions=all_cross_attentions,
554
+ )
555
+
556
+
557
+ class BertPooler(nn.Module):
558
+ def __init__(self, config):
559
+ super().__init__()
560
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
561
+ self.activation = nn.Tanh()
562
+
563
+ def forward(self, hidden_states):
564
+ # We "pool" the model by simply taking the hidden state corresponding
565
+ # to the first token.
566
+ first_token_tensor = hidden_states[:, 0]
567
+ pooled_output = self.dense(first_token_tensor)
568
+ pooled_output = self.activation(pooled_output)
569
+ return pooled_output
570
+
571
+
572
+ class BertPredictionHeadTransform(nn.Module):
573
+ def __init__(self, config):
574
+ super().__init__()
575
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
576
+ if isinstance(config.hidden_act, str):
577
+ self.transform_act_fn = ACT2FN[config.hidden_act]
578
+ else:
579
+ self.transform_act_fn = config.hidden_act
580
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
581
+
582
+ def forward(self, hidden_states):
583
+ hidden_states = self.dense(hidden_states)
584
+ hidden_states = self.transform_act_fn(hidden_states)
585
+ hidden_states = self.LayerNorm(hidden_states)
586
+ return hidden_states
587
+
588
+
589
+ class BertLMPredictionHead(nn.Module):
590
+ def __init__(self, config):
591
+ super().__init__()
592
+ self.transform = BertPredictionHeadTransform(config)
593
+
594
+ # The output weights are the same as the input embeddings, but there is
595
+ # an output-only bias for each token.
596
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
597
+
598
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
599
+
600
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
601
+ self.decoder.bias = self.bias
602
+
603
+ def forward(self, hidden_states):
604
+ hidden_states = self.transform(hidden_states)
605
+ hidden_states = self.decoder(hidden_states)
606
+ return hidden_states
607
+
608
+
609
+ class BertOnlyMLMHead(nn.Module):
610
+ def __init__(self, config):
611
+ super().__init__()
612
+ self.predictions = BertLMPredictionHead(config)
613
+
614
+ def forward(self, sequence_output):
615
+ prediction_scores = self.predictions(sequence_output)
616
+ return prediction_scores
617
+
618
+
619
+ class BertPreTrainedModel(PreTrainedModel):
620
+ """
621
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
622
+ models.
623
+ """
624
+
625
+ config_class = BertConfig
626
+ base_model_prefix = "bert"
627
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
628
+
629
+ def _init_weights(self, module):
630
+ """ Initialize the weights """
631
+ if isinstance(module, (nn.Linear, nn.Embedding)):
632
+ # Slightly different from the TF version which uses truncated_normal for initialization
633
+ # cf https://github.com/pytorch/pytorch/pull/5617
634
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
635
+ elif isinstance(module, nn.LayerNorm):
636
+ module.bias.data.zero_()
637
+ module.weight.data.fill_(1.0)
638
+ if isinstance(module, nn.Linear) and module.bias is not None:
639
+ module.bias.data.zero_()
640
+
641
+
642
+ class BertModel(BertPreTrainedModel):
643
+ """
644
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
645
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
646
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
647
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
648
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
649
+ input to the forward pass.
650
+ """
651
+
652
+ def __init__(self, config, add_pooling_layer=True):
653
+ super().__init__(config)
654
+ self.config = config
655
+
656
+ self.embeddings = BertEmbeddings(config)
657
+
658
+ self.encoder = BertEncoder(config)
659
+
660
+ self.pooler = BertPooler(config) if add_pooling_layer else None
661
+
662
+ self.init_weights()
663
+
664
+
665
+ def get_input_embeddings(self):
666
+ return self.embeddings.word_embeddings
667
+
668
+ def set_input_embeddings(self, value):
669
+ self.embeddings.word_embeddings = value
670
+
671
+ def _prune_heads(self, heads_to_prune):
672
+ """
673
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
674
+ class PreTrainedModel
675
+ """
676
+ for layer, heads in heads_to_prune.items():
677
+ self.encoder.layer[layer].attention.prune_heads(heads)
678
+
679
+
680
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
681
+ """
682
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
683
+
684
+ Arguments:
685
+ attention_mask (:obj:`torch.Tensor`):
686
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
687
+ input_shape (:obj:`Tuple[int]`):
688
+ The shape of the input to the model.
689
+ device: (:obj:`torch.device`):
690
+ The device of the input to the model.
691
+
692
+ Returns:
693
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
694
+ """
695
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
696
+ # ourselves in which case we just need to make it broadcastable to all heads.
697
+ if attention_mask.dim() == 3:
698
+ extended_attention_mask = attention_mask[:, None, :, :]
699
+ elif attention_mask.dim() == 2:
700
+ # Provided a padding mask of dimensions [batch_size, seq_length]
701
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
702
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
703
+ if is_decoder:
704
+ batch_size, seq_length = input_shape
705
+
706
+ seq_ids = torch.arange(seq_length, device=device)
707
+ causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
708
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
709
+ # causal and attention masks must have same type with pytorch version < 1.3
710
+ causal_mask = causal_mask.to(attention_mask.dtype)
711
+
712
+ if causal_mask.shape[1] < attention_mask.shape[1]:
713
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
714
+ causal_mask = torch.cat(
715
+ [
716
+ torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
717
+ causal_mask,
718
+ ],
719
+ axis=-1,
720
+ )
721
+
722
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
723
+ else:
724
+ extended_attention_mask = attention_mask[:, None, None, :]
725
+ else:
726
+ raise ValueError(
727
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
728
+ input_shape, attention_mask.shape
729
+ )
730
+ )
731
+
732
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
733
+ # masked positions, this operation will create a tensor which is 0.0 for
734
+ # positions we want to attend and -10000.0 for masked positions.
735
+ # Since we are adding it to the raw scores before the softmax, this is
736
+ # effectively the same as removing these entirely.
737
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
738
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
739
+ return extended_attention_mask
740
+
741
+ def forward(
742
+ self,
743
+ input_ids=None,
744
+ attention_mask=None,
745
+ position_ids=None,
746
+ head_mask=None,
747
+ inputs_embeds=None,
748
+ encoder_embeds=None,
749
+ encoder_hidden_states=None,
750
+ encoder_attention_mask=None,
751
+ past_key_values=None,
752
+ use_cache=None,
753
+ output_attentions=None,
754
+ output_hidden_states=None,
755
+ return_dict=None,
756
+ is_decoder=False,
757
+ mode='multimodal',
758
+ ):
759
+ r"""
760
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
761
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
762
+ the model is configured as a decoder.
763
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
764
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
765
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
766
+ - 1 for tokens that are **not masked**,
767
+ - 0 for tokens that are **masked**.
768
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
769
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
770
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
771
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
772
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
773
+ use_cache (:obj:`bool`, `optional`):
774
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
775
+ decoding (see :obj:`past_key_values`).
776
+ """
777
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
778
+ output_hidden_states = (
779
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
780
+ )
781
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
782
+
783
+ if is_decoder:
784
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
785
+ else:
786
+ use_cache = False
787
+
788
+ if input_ids is not None and inputs_embeds is not None:
789
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
790
+ elif input_ids is not None:
791
+ input_shape = input_ids.size()
792
+ batch_size, seq_length = input_shape
793
+ device = input_ids.device
794
+ elif inputs_embeds is not None:
795
+ input_shape = inputs_embeds.size()[:-1]
796
+ batch_size, seq_length = input_shape
797
+ device = inputs_embeds.device
798
+ elif encoder_embeds is not None:
799
+ input_shape = encoder_embeds.size()[:-1]
800
+ batch_size, seq_length = input_shape
801
+ device = encoder_embeds.device
802
+ else:
803
+ raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")
804
+
805
+ # past_key_values_length
806
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
807
+
808
+ if attention_mask is None:
809
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
810
+
811
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
812
+ # ourselves in which case we just need to make it broadcastable to all heads.
813
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
814
+ device, is_decoder)
815
+
816
+ # If a 2D or 3D attention mask is provided for the cross-attention
817
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
818
+ if encoder_hidden_states is not None:
819
+ if type(encoder_hidden_states) == list:
820
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
821
+ else:
822
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
823
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
824
+
825
+ if type(encoder_attention_mask) == list:
826
+ encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
827
+ elif encoder_attention_mask is None:
828
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
829
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
830
+ else:
831
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
832
+ else:
833
+ encoder_extended_attention_mask = None
834
+
835
+ # Prepare head mask if needed
836
+ # 1.0 in head_mask indicate we keep the head
837
+ # attention_probs has shape bsz x n_heads x N x N
838
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
839
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
840
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
841
+
842
+ if encoder_embeds is None:
843
+ embedding_output = self.embeddings(
844
+ input_ids=input_ids,
845
+ position_ids=position_ids,
846
+ inputs_embeds=inputs_embeds,
847
+ past_key_values_length=past_key_values_length,
848
+ )
849
+ else:
850
+ embedding_output = encoder_embeds
851
+
852
+ encoder_outputs = self.encoder(
853
+ embedding_output,
854
+ attention_mask=extended_attention_mask,
855
+ head_mask=head_mask,
856
+ encoder_hidden_states=encoder_hidden_states,
857
+ encoder_attention_mask=encoder_extended_attention_mask,
858
+ past_key_values=past_key_values,
859
+ use_cache=use_cache,
860
+ output_attentions=output_attentions,
861
+ output_hidden_states=output_hidden_states,
862
+ return_dict=return_dict,
863
+ mode=mode,
864
+ )
865
+ sequence_output = encoder_outputs[0]
866
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
867
+
868
+ if not return_dict:
869
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
870
+
871
+ return BaseModelOutputWithPoolingAndCrossAttentions(
872
+ last_hidden_state=sequence_output,
873
+ pooler_output=pooled_output,
874
+ past_key_values=encoder_outputs.past_key_values,
875
+ hidden_states=encoder_outputs.hidden_states,
876
+ attentions=encoder_outputs.attentions,
877
+ cross_attentions=encoder_outputs.cross_attentions,
878
+ )
879
+
880
+
881
+ class BertLMHeadModel(BertPreTrainedModel):
882
+
883
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
884
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
885
+
886
+ def __init__(self, config):
887
+ super().__init__(config)
888
+
889
+ self.bert = BertModel(config, add_pooling_layer=False)
890
+ self.cls = BertOnlyMLMHead(config)
891
+
892
+ self.init_weights()
893
+
894
+ def get_output_embeddings(self):
895
+ return self.cls.predictions.decoder
896
+
897
+ def set_output_embeddings(self, new_embeddings):
898
+ self.cls.predictions.decoder = new_embeddings
899
+
900
+ def forward(
901
+ self,
902
+ input_ids=None,
903
+ attention_mask=None,
904
+ position_ids=None,
905
+ head_mask=None,
906
+ inputs_embeds=None,
907
+ encoder_hidden_states=None,
908
+ encoder_attention_mask=None,
909
+ labels=None,
910
+ past_key_values=None,
911
+ use_cache=None,
912
+ output_attentions=None,
913
+ output_hidden_states=None,
914
+ return_dict=None,
915
+ return_logits=False,
916
+ is_decoder=True,
917
+ reduction='mean',
918
+ mode='multimodal',
919
+ ):
920
+ r"""
921
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
922
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
923
+ the model is configured as a decoder.
924
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
925
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
926
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
927
+ - 1 for tokens that are **not masked**,
928
+ - 0 for tokens that are **masked**.
929
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
930
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
931
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
932
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
933
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
934
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
935
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
936
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
937
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
938
+ use_cache (:obj:`bool`, `optional`):
939
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
940
+ decoding (see :obj:`past_key_values`).
941
+ Returns:
942
+ Example::
943
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
944
+ >>> import torch
945
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
946
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
947
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
948
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
949
+ >>> outputs = model(**inputs)
950
+ >>> prediction_logits = outputs.logits
951
+ """
952
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
953
+ if labels is not None:
954
+ use_cache = False
955
+
956
+ outputs = self.bert(
957
+ input_ids,
958
+ attention_mask=attention_mask,
959
+ position_ids=position_ids,
960
+ head_mask=head_mask,
961
+ inputs_embeds=inputs_embeds,
962
+ encoder_hidden_states=encoder_hidden_states,
963
+ encoder_attention_mask=encoder_attention_mask,
964
+ past_key_values=past_key_values,
965
+ use_cache=use_cache,
966
+ output_attentions=output_attentions,
967
+ output_hidden_states=output_hidden_states,
968
+ return_dict=return_dict,
969
+ is_decoder=is_decoder,
970
+ mode=mode,
971
+ )
972
+
973
+ sequence_output = outputs[0]
974
+ prediction_scores = self.cls(sequence_output)
975
+ # sequence_output.shape torch.Size([85, 30, 768])
976
+ # prediction_scores.shape torch.Size([85, 30, 30524])
977
+ # labels.shape torch.Size([85, 30])
978
+
979
+
980
+ if return_logits:
981
+ return prediction_scores[:, :-1, :].contiguous()
982
+
983
+ lm_loss = None
984
+ if labels is not None:
985
+ # we are doing next-token prediction; shift prediction scores and input ids by one
986
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
987
+ labels = labels[:, 1:].contiguous()
988
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
989
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
990
+ if reduction=='none':
991
+ lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1)
992
+
993
+ if not return_dict:
994
+ output = (prediction_scores,) + outputs[2:]
995
+ return ((lm_loss,) + output) if lm_loss is not None else output
996
+
997
+ return CausalLMOutputWithCrossAttentions(
998
+ loss=lm_loss,
999
+ logits=prediction_scores,
1000
+ past_key_values=outputs.past_key_values,
1001
+ hidden_states=outputs.hidden_states,
1002
+ attentions=outputs.attentions,
1003
+ cross_attentions=outputs.cross_attentions,
1004
+ )
1005
+
1006
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1007
+ input_shape = input_ids.shape
1008
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1009
+ if attention_mask is None:
1010
+ attention_mask = input_ids.new_ones(input_shape)
1011
+
1012
+ # cut decoder_input_ids if past is used
1013
+ if past is not None:
1014
+ input_ids = input_ids[:, -1:]
1015
+
1016
+ return {
1017
+ "input_ids": input_ids,
1018
+ "attention_mask": attention_mask,
1019
+ "past_key_values": past,
1020
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1021
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1022
+ "is_decoder": True,
1023
+ }
1024
+
1025
+ def _reorder_cache(self, past, beam_idx):
1026
+ reordered_past = ()
1027
+ for layer_past in past:
1028
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1029
+ return reordered_past
1030
+
1031
+
models/swin_transformer.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # Swin Transformer
3
+ # Copyright (c) 2021 Microsoft
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # Written by Ze Liu
6
+ # --------------------------------------------------------
7
+
8
+ import numpy as np
9
+ from scipy import interpolate
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.utils.checkpoint as checkpoint
14
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19
+ super().__init__()
20
+ out_features = out_features or in_features
21
+ hidden_features = hidden_features or in_features
22
+ self.fc1 = nn.Linear(in_features, hidden_features)
23
+ self.act = act_layer()
24
+ self.fc2 = nn.Linear(hidden_features, out_features)
25
+ self.drop = nn.Dropout(drop)
26
+
27
+ def forward(self, x):
28
+ x = self.fc1(x)
29
+ x = self.act(x)
30
+ x = self.drop(x)
31
+ x = self.fc2(x)
32
+ x = self.drop(x)
33
+ return x
34
+
35
+
36
+ def window_partition(x, window_size):
37
+ """
38
+ Args:
39
+ x: (B, H, W, C)
40
+ window_size (int): window size
41
+
42
+ Returns:
43
+ windows: (num_windows*B, window_size, window_size, C)
44
+ """
45
+ B, H, W, C = x.shape
46
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
47
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
48
+ return windows
49
+
50
+
51
+ def window_reverse(windows, window_size, H, W):
52
+ """
53
+ Args:
54
+ windows: (num_windows*B, window_size, window_size, C)
55
+ window_size (int): Window size
56
+ H (int): Height of image
57
+ W (int): Width of image
58
+
59
+ Returns:
60
+ x: (B, H, W, C)
61
+ """
62
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
63
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
64
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
65
+ return x
66
+
67
+
68
+ class WindowAttention(nn.Module):
69
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
70
+ It supports both of shifted and non-shifted window.
71
+
72
+ Args:
73
+ dim (int): Number of input channels.
74
+ window_size (tuple[int]): The height and width of the window.
75
+ num_heads (int): Number of attention heads.
76
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
77
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
78
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
79
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
80
+ """
81
+
82
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
83
+
84
+ super().__init__()
85
+ self.dim = dim
86
+ self.window_size = window_size # Wh, Ww
87
+ self.num_heads = num_heads
88
+ head_dim = dim // num_heads
89
+ self.scale = qk_scale or head_dim ** -0.5
90
+
91
+ # define a parameter table of relative position bias
92
+ self.relative_position_bias_table = nn.Parameter(
93
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
94
+
95
+ # get pair-wise relative position index for each token inside the window
96
+ coords_h = torch.arange(self.window_size[0])
97
+ coords_w = torch.arange(self.window_size[1])
98
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
99
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
100
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
101
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
102
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
103
+ relative_coords[:, :, 1] += self.window_size[1] - 1
104
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
105
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
106
+ self.register_buffer("relative_position_index", relative_position_index)
107
+
108
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
109
+ self.attn_drop = nn.Dropout(attn_drop)
110
+ self.proj = nn.Linear(dim, dim)
111
+ self.proj_drop = nn.Dropout(proj_drop)
112
+
113
+ trunc_normal_(self.relative_position_bias_table, std=.02)
114
+ self.softmax = nn.Softmax(dim=-1)
115
+
116
+ def forward(self, x, mask=None):
117
+ """
118
+ Args:
119
+ x: input features with shape of (num_windows*B, N, C)
120
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
121
+ """
122
+ B_, N, C = x.shape
123
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
125
+
126
+ q = q * self.scale
127
+ attn = (q @ k.transpose(-2, -1))
128
+
129
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
130
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
131
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
132
+ attn = attn + relative_position_bias.unsqueeze(0)
133
+
134
+ if mask is not None:
135
+ nW = mask.shape[0]
136
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
137
+ attn = attn.view(-1, self.num_heads, N, N)
138
+ attn = self.softmax(attn)
139
+ else:
140
+ attn = self.softmax(attn)
141
+
142
+ attn = self.attn_drop(attn)
143
+
144
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
145
+ x = self.proj(x)
146
+ x = self.proj_drop(x)
147
+ return x
148
+
149
+ def extra_repr(self) -> str:
150
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
151
+
152
+ def flops(self, N):
153
+ # calculate flops for 1 window with token length of N
154
+ flops = 0
155
+ # qkv = self.qkv(x)
156
+ flops += N * self.dim * 3 * self.dim
157
+ # attn = (q @ k.transpose(-2, -1))
158
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
159
+ # x = (attn @ v)
160
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
161
+ # x = self.proj(x)
162
+ flops += N * self.dim * self.dim
163
+ return flops
164
+
165
+
166
+ class SwinTransformerBlock(nn.Module):
167
+ r""" Swin Transformer Block.
168
+
169
+ Args:
170
+ dim (int): Number of input channels.
171
+ input_resolution (tuple[int]): Input resulotion.
172
+ num_heads (int): Number of attention heads.
173
+ window_size (int): Window size.
174
+ shift_size (int): Shift size for SW-MSA.
175
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
176
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
177
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
178
+ drop (float, optional): Dropout rate. Default: 0.0
179
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
180
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
181
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
182
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
183
+ """
184
+
185
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
186
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
187
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
188
+ super().__init__()
189
+ self.dim = dim
190
+ self.input_resolution = input_resolution
191
+ self.num_heads = num_heads
192
+ self.window_size = window_size
193
+ self.shift_size = shift_size
194
+ self.mlp_ratio = mlp_ratio
195
+ if min(self.input_resolution) <= self.window_size:
196
+ # if window size is larger than input resolution, we don't partition windows
197
+ self.shift_size = 0
198
+ self.window_size = min(self.input_resolution)
199
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
200
+
201
+ self.norm1 = norm_layer(dim)
202
+ self.attn = WindowAttention(
203
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
204
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
205
+
206
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
207
+ self.norm2 = norm_layer(dim)
208
+ mlp_hidden_dim = int(dim * mlp_ratio)
209
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
210
+
211
+ if self.shift_size > 0:
212
+ # calculate attention mask for SW-MSA
213
+ H, W = self.input_resolution
214
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
215
+ h_slices = (slice(0, -self.window_size),
216
+ slice(-self.window_size, -self.shift_size),
217
+ slice(-self.shift_size, None))
218
+ w_slices = (slice(0, -self.window_size),
219
+ slice(-self.window_size, -self.shift_size),
220
+ slice(-self.shift_size, None))
221
+ cnt = 0
222
+ for h in h_slices:
223
+ for w in w_slices:
224
+ img_mask[:, h, w, :] = cnt
225
+ cnt += 1
226
+
227
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
228
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
229
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
230
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
231
+ else:
232
+ attn_mask = None
233
+
234
+ self.register_buffer("attn_mask", attn_mask)
235
+
236
+ def forward(self, x):
237
+ H, W = self.input_resolution
238
+ B, L, C = x.shape
239
+ assert L == H * W, "input feature has wrong size"
240
+
241
+ shortcut = x
242
+ x = self.norm1(x)
243
+ x = x.view(B, H, W, C)
244
+
245
+ # cyclic shift
246
+ if self.shift_size > 0:
247
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
248
+ else:
249
+ shifted_x = x
250
+
251
+ # partition windows
252
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
253
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
254
+
255
+ # W-MSA/SW-MSA
256
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
257
+
258
+ # merge windows
259
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
260
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
261
+
262
+ # reverse cyclic shift
263
+ if self.shift_size > 0:
264
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
265
+ else:
266
+ x = shifted_x
267
+ x = x.view(B, H * W, C)
268
+
269
+ # FFN
270
+ x = shortcut + self.drop_path(x)
271
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
272
+
273
+ return x
274
+
275
+ def extra_repr(self) -> str:
276
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
277
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
278
+
279
+ def flops(self):
280
+ flops = 0
281
+ H, W = self.input_resolution
282
+ # norm1
283
+ flops += self.dim * H * W
284
+ # W-MSA/SW-MSA
285
+ nW = H * W / self.window_size / self.window_size
286
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
287
+ # mlp
288
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
289
+ # norm2
290
+ flops += self.dim * H * W
291
+ return flops
292
+
293
+
294
+ class PatchMerging(nn.Module):
295
+ r""" Patch Merging Layer.
296
+
297
+ Args:
298
+ input_resolution (tuple[int]): Resolution of input feature.
299
+ dim (int): Number of input channels.
300
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
301
+ """
302
+
303
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
304
+ super().__init__()
305
+ self.input_resolution = input_resolution
306
+ self.dim = dim
307
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
308
+ self.norm = norm_layer(4 * dim)
309
+
310
+ def forward(self, x):
311
+ """
312
+ x: B, H*W, C
313
+ """
314
+ H, W = self.input_resolution
315
+ B, L, C = x.shape
316
+ assert L == H * W, "input feature has wrong size"
317
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
318
+
319
+ x = x.view(B, H, W, C)
320
+
321
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
322
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
323
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
324
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
325
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
326
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
327
+
328
+ x = self.norm(x)
329
+ x = self.reduction(x)
330
+
331
+ return x
332
+
333
+ def extra_repr(self) -> str:
334
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
335
+
336
+ def flops(self):
337
+ H, W = self.input_resolution
338
+ flops = H * W * self.dim
339
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
340
+ return flops
341
+
342
+
343
+ class BasicLayer(nn.Module):
344
+ """ A basic Swin Transformer layer for one stage.
345
+
346
+ Args:
347
+ dim (int): Number of input channels.
348
+ input_resolution (tuple[int]): Input resolution.
349
+ depth (int): Number of blocks.
350
+ num_heads (int): Number of attention heads.
351
+ window_size (int): Local window size.
352
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
353
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
354
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
355
+ drop (float, optional): Dropout rate. Default: 0.0
356
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
357
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
358
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
359
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
360
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
361
+ """
362
+
363
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
364
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
365
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
366
+
367
+ super().__init__()
368
+ self.dim = dim
369
+ self.input_resolution = input_resolution
370
+ self.depth = depth
371
+ self.use_checkpoint = use_checkpoint
372
+
373
+ # build blocks
374
+ self.blocks = nn.ModuleList([
375
+ SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
376
+ num_heads=num_heads, window_size=window_size,
377
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
378
+ mlp_ratio=mlp_ratio,
379
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
380
+ drop=drop, attn_drop=attn_drop,
381
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
382
+ norm_layer=norm_layer)
383
+ for i in range(depth)])
384
+
385
+ # patch merging layer
386
+ if downsample is not None:
387
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
388
+ else:
389
+ self.downsample = None
390
+
391
+ def forward(self, x):
392
+ for blk in self.blocks:
393
+ if self.use_checkpoint:
394
+ x = checkpoint.checkpoint(blk, x)
395
+ else:
396
+ x = blk(x)
397
+ if self.downsample is not None:
398
+ x = self.downsample(x)
399
+ return x
400
+
401
+ def extra_repr(self) -> str:
402
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
403
+
404
+ def flops(self):
405
+ flops = 0
406
+ for blk in self.blocks:
407
+ flops += blk.flops()
408
+ if self.downsample is not None:
409
+ flops += self.downsample.flops()
410
+ return flops
411
+
412
+
413
+ class PatchEmbed(nn.Module):
414
+ r""" Image to Patch Embedding
415
+
416
+ Args:
417
+ img_size (int): Image size. Default: 224.
418
+ patch_size (int): Patch token size. Default: 4.
419
+ in_chans (int): Number of input image channels. Default: 3.
420
+ embed_dim (int): Number of linear projection output channels. Default: 96.
421
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
422
+ """
423
+
424
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
425
+ super().__init__()
426
+ img_size = to_2tuple(img_size)
427
+ patch_size = to_2tuple(patch_size)
428
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
429
+ self.img_size = img_size
430
+ self.patch_size = patch_size
431
+ self.patches_resolution = patches_resolution
432
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
433
+
434
+ self.in_chans = in_chans
435
+ self.embed_dim = embed_dim
436
+
437
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
438
+ if norm_layer is not None:
439
+ self.norm = norm_layer(embed_dim)
440
+ else:
441
+ self.norm = None
442
+
443
+ def forward(self, x):
444
+ B, C, H, W = x.shape
445
+ # FIXME look at relaxing size constraints
446
+ assert H == self.img_size[0] and W == self.img_size[1], \
447
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
448
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
449
+ if self.norm is not None:
450
+ x = self.norm(x)
451
+ return x
452
+
453
+ def flops(self):
454
+ Ho, Wo = self.patches_resolution
455
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
456
+ if self.norm is not None:
457
+ flops += Ho * Wo * self.embed_dim
458
+ return flops
459
+
460
+
461
+ class SwinTransformer(nn.Module):
462
+ r""" Swin Transformer
463
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
464
+ https://arxiv.org/pdf/2103.14030
465
+
466
+ Args:
467
+ img_size (int | tuple(int)): Input image size. Default 224
468
+ patch_size (int | tuple(int)): Patch size. Default: 4
469
+ in_chans (int): Number of input image channels. Default: 3
470
+ num_classes (int): Number of classes for classification head. Default: 1000
471
+ embed_dim (int): Patch embedding dimension. Default: 96
472
+ depths (tuple(int)): Depth of each Swin Transformer layer.
473
+ num_heads (tuple(int)): Number of attention heads in different layers.
474
+ window_size (int): Window size. Default: 7
475
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
476
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
477
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
478
+ drop_rate (float): Dropout rate. Default: 0
479
+ attn_drop_rate (float): Attention dropout rate. Default: 0
480
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
481
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
482
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
483
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
484
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
485
+ """
486
+
487
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
488
+ embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
489
+ window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
490
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
491
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
492
+ use_checkpoint=False, **kwargs):
493
+ super().__init__()
494
+
495
+ self.num_classes = num_classes
496
+ self.num_layers = len(depths)
497
+ self.embed_dim = embed_dim
498
+ self.ape = ape
499
+ self.patch_norm = patch_norm
500
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
501
+ self.mlp_ratio = mlp_ratio
502
+
503
+ # split image into non-overlapping patches
504
+ self.patch_embed = PatchEmbed(
505
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
506
+ norm_layer=norm_layer if self.patch_norm else None)
507
+ num_patches = self.patch_embed.num_patches
508
+ patches_resolution = self.patch_embed.patches_resolution
509
+ self.patches_resolution = patches_resolution
510
+
511
+ # absolute position embedding
512
+ if self.ape:
513
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
514
+ trunc_normal_(self.absolute_pos_embed, std=.02)
515
+
516
+ self.pos_drop = nn.Dropout(p=drop_rate)
517
+
518
+ # stochastic depth
519
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
520
+
521
+ # build layers
522
+ self.layers = nn.ModuleList()
523
+ for i_layer in range(self.num_layers):
524
+ layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
525
+ input_resolution=(patches_resolution[0] // (2 ** i_layer),
526
+ patches_resolution[1] // (2 ** i_layer)),
527
+ depth=depths[i_layer],
528
+ num_heads=num_heads[i_layer],
529
+ window_size=window_size,
530
+ mlp_ratio=self.mlp_ratio,
531
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
532
+ drop=drop_rate, attn_drop=attn_drop_rate,
533
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
534
+ norm_layer=norm_layer,
535
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
536
+ use_checkpoint=use_checkpoint)
537
+ self.layers.append(layer)
538
+
539
+ self.norm = norm_layer(self.num_features)
540
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
541
+ # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
542
+
543
+ self.apply(self._init_weights)
544
+
545
+ def _init_weights(self, m):
546
+ if isinstance(m, nn.Linear):
547
+ trunc_normal_(m.weight, std=.02)
548
+ if isinstance(m, nn.Linear) and m.bias is not None:
549
+ nn.init.constant_(m.bias, 0)
550
+ elif isinstance(m, nn.LayerNorm):
551
+ nn.init.constant_(m.bias, 0)
552
+ nn.init.constant_(m.weight, 1.0)
553
+
554
+ @torch.jit.ignore
555
+ def no_weight_decay(self):
556
+ return {'absolute_pos_embed'}
557
+
558
+ @torch.jit.ignore
559
+ def no_weight_decay_keywords(self):
560
+ return {'relative_position_bias_table'}
561
+
562
+ def forward(self, x, idx_to_group_img=None, image_atts=None, **kwargs):
563
+ x = self.patch_embed(x)
564
+ if self.ape:
565
+ x = x + self.absolute_pos_embed
566
+ x = self.pos_drop(x)
567
+
568
+ for layer in self.layers:
569
+ x = layer(x)
570
+
571
+ x = self.norm(x) # B L C
572
+
573
+ x_cls = self.avgpool(x.transpose(1, 2)) # B C 1
574
+
575
+ if idx_to_group_img is None:
576
+ return torch.cat([x_cls.transpose(1, 2), x], dim=1)
577
+ else:
578
+ x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2]))
579
+ weights = image_atts[:, 1:].unsqueeze(2) # B L 1
580
+ x_bs_cls = torch.sum((weights * x_bs).transpose(1, 2), dim=-1, keepdim=True) # B C 1
581
+ x_bs_cls = x_bs_cls / torch.sum(weights.transpose(1, 2), dim=-1, keepdim=True) # avgpool
582
+
583
+ return torch.cat([x_bs_cls.transpose(1, 2), x_bs], dim=1), \
584
+ torch.cat([x_cls.transpose(1, 2), x], dim=1)
585
+
586
+ def flops(self):
587
+ flops = 0
588
+ flops += self.patch_embed.flops()
589
+ for i, layer in enumerate(self.layers):
590
+ flops += layer.flops()
591
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
592
+ flops += self.num_features * self.num_classes
593
+ return flops
594
+
595
+
596
+ def interpolate_relative_pos_embed(rel_pos_bias, dst_num_pos, param_name=''):
597
+ # from: https://github.com/microsoft/unilm/blob/8a0a1c1f4e7326938ea7580a00d56d7f17d65612/beit/run_class_finetuning.py#L348
598
+
599
+ # rel_pos_bias: relative_position_bias_table
600
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
601
+
602
+ num_extra_tokens = 0
603
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
604
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
605
+ if src_size != dst_size:
606
+ print("Position interpolate %s from %dx%d to %dx%d" % (param_name, src_size, src_size, dst_size, dst_size))
607
+
608
+ # extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
609
+ # rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
610
+
611
+ def geometric_progression(a, r, n):
612
+ return a * (1.0 - r ** n) / (1.0 - r)
613
+
614
+ left, right = 1.01, 1.5
615
+ while right - left > 1e-6:
616
+ q = (left + right) / 2.0
617
+ gp = geometric_progression(1, q, src_size // 2)
618
+ if gp > dst_size // 2:
619
+ right = q
620
+ else:
621
+ left = q
622
+
623
+ # if q > 1.090307:
624
+ # q = 1.090307
625
+
626
+ dis = []
627
+ cur = 1
628
+ for i in range(src_size // 2):
629
+ dis.append(cur)
630
+ cur += q ** (i + 1)
631
+
632
+ r_ids = [-_ for _ in reversed(dis)]
633
+
634
+ x = r_ids + [0] + dis
635
+ y = r_ids + [0] + dis
636
+
637
+ t = dst_size // 2.0
638
+ dx = np.arange(-t, t + 0.1, 1.0)
639
+ dy = np.arange(-t, t + 0.1, 1.0)
640
+
641
+ # print("Original positions = %s" % str(x))
642
+ # print("Target positions = %s" % str(dx))
643
+
644
+ all_rel_pos_bias = []
645
+
646
+ for i in range(num_attn_heads):
647
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
648
+ f = interpolate.interp2d(x, y, z, kind='cubic')
649
+ all_rel_pos_bias.append(
650
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
651
+
652
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
653
+
654
+ return rel_pos_bias
models/tag2text.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Tag2Text
3
+ * Written by Xinyu Huang
4
+ '''
5
+ import warnings
6
+ warnings.filterwarnings("ignore")
7
+
8
+ from models.vit import VisionTransformer, interpolate_pos_embed
9
+ from models.swin_transformer import SwinTransformer, interpolate_relative_pos_embed
10
+ from models.med import BertConfig, BertModel, BertLMHeadModel
11
+ from transformers import BertTokenizer
12
+
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+
17
+ import os
18
+ from urllib.parse import urlparse
19
+ from timm.models.hub import download_cached_file
20
+ from data.tag_class import tra_array
21
+ import json
22
+ import math
23
+ import numpy as np
24
+
25
+ def read_json(rpath):
26
+ with open(rpath, 'r') as f:
27
+ return json.load(f)
28
+
29
+ class Tag2Text_Caption(nn.Module):
30
+ def __init__(self,
31
+ med_config = 'configs/med_config.json',
32
+ image_size = 384,
33
+ vit = 'base',
34
+ vit_grad_ckpt = False,
35
+ vit_ckpt_layer = 0,
36
+ prompt = 'a picture of ',
37
+ config = None,
38
+ threshold = 0.2,
39
+ ):
40
+ """
41
+ Args:
42
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
43
+ image_size (int): input image size
44
+ vit (str): model size of vision transformer
45
+ """
46
+ super().__init__()
47
+
48
+ if vit=='swin_b':
49
+ if image_size == 224:
50
+ vision_config_path = 'configs/swin/config_swinB_224.json'
51
+ elif image_size == 384:
52
+ vision_config_path = 'configs/swin/config_swinB_384.json'
53
+ vision_config = read_json(vision_config_path)
54
+ assert image_size == vision_config['image_res']
55
+ # assert config['patch_size'] == 32
56
+ vision_width = vision_config['vision_width']
57
+
58
+ self.visual_encoder = SwinTransformer(img_size=vision_config['image_res'],
59
+ patch_size=4,
60
+ in_chans=3,
61
+ embed_dim=vision_config['embed_dim'],
62
+ depths=vision_config['depths'],
63
+ num_heads=vision_config['num_heads'],
64
+ window_size=vision_config['window_size'],
65
+ mlp_ratio=4.,
66
+ qkv_bias=True,
67
+ drop_rate=0.0,
68
+ drop_path_rate=0.1,
69
+ ape=False,
70
+ patch_norm=True,
71
+ use_checkpoint=False)
72
+
73
+ else:
74
+ self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
75
+
76
+
77
+ self.tokenizer = init_tokenizer()
78
+
79
+ # create the decoder
80
+ decoder_config = BertConfig.from_json_file(med_config)
81
+ decoder_config.encoder_width = 768
82
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
83
+
84
+ # create encoder
85
+ encoder_config = BertConfig.from_json_file(med_config)
86
+ encoder_config.encoder_width = vision_width
87
+ self.tag_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
88
+
89
+ self.prompt = prompt
90
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
91
+
92
+ self.threshold = threshold
93
+ num_features = 768
94
+ self.num_class = config['class_num']
95
+
96
+ q2l_config = BertConfig.from_json_file('configs/q2l_config.json')
97
+ q2l_config.encoder_width = vision_width
98
+ self.vision_multi = BertModel.from_pretrained('bert-base-uncased',config=q2l_config, add_pooling_layer=False)
99
+ self.vision_multi.resize_token_embeddings(len(self.tokenizer))
100
+ self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
101
+ self.fc = GroupWiseLinear(self.num_class, num_features, bias=True)
102
+ self.del_selfattention()
103
+
104
+ tie_encoder_decoder_weights(self.tag_encoder,self.vision_multi,'',' ')
105
+ self.tag_array = tra_array
106
+
107
+ def del_selfattention(self):
108
+ del self.vision_multi.embeddings
109
+ for layer in self.vision_multi.encoder.layer:
110
+ del layer.attention
111
+
112
+ def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0, tag_input = None, return_tag_predict = False):
113
+ image_embeds = self.visual_encoder(image)
114
+ image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
115
+
116
+ #==============generate tag==============#
117
+ if tag_input == None:
118
+ image_spatial_embeds = image_embeds[:,1:,:]
119
+ image_cls_embeds = image_embeds[:,0,:]
120
+
121
+ bs = image_spatial_embeds.shape[0]
122
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs,1,1)
123
+ mlr_tagembedding = self.vision_multi(encoder_embeds = label_embed,
124
+ encoder_hidden_states = image_embeds,
125
+ encoder_attention_mask = image_atts,
126
+ return_dict = False,
127
+ mode = 'mlr',
128
+ )
129
+
130
+ logits = self.fc(mlr_tagembedding[0])
131
+
132
+ targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
133
+
134
+ tag = targets.cpu().numpy()
135
+ bs = image.size(0)
136
+ tag_input = []
137
+ for b in range(bs):
138
+ index = np.argwhere(tag[b] == 1)
139
+ token = self.tag_array[index].squeeze(axis = 1)
140
+ tag_input.append(' | '.join(token))
141
+ #========================================#
142
+
143
+ if not sample:
144
+ image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
145
+ tag_input_temp = []
146
+ for tag in tag_input:
147
+ for i in range(num_beams):
148
+ tag_input_temp.append(tag)
149
+ tag_input = tag_input_temp
150
+
151
+
152
+ tag_input_tokenzier = self.tokenizer(tag_input, padding='max_length', truncation=True, max_length=40,
153
+ return_tensors="pt").to(image.device)
154
+ encoder_input_ids = tag_input_tokenzier.input_ids
155
+ encoder_input_ids[:,0] = self.tokenizer.enc_token_id
156
+
157
+ output_tagembedding = self.tag_encoder(encoder_input_ids,
158
+ attention_mask = tag_input_tokenzier.attention_mask,
159
+ encoder_hidden_states = image_embeds,
160
+ encoder_attention_mask = image_atts,
161
+ return_dict = True,
162
+ )
163
+
164
+ prompt = [self.prompt] * image.size(0)
165
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
166
+ input_ids[:,0] = self.tokenizer.bos_token_id
167
+ input_ids = input_ids[:, :-1]
168
+
169
+ if sample:
170
+ #nucleus sampling
171
+ model_kwargs = {"encoder_hidden_states": output_tagembedding.last_hidden_state, "encoder_attention_mask":None}
172
+ outputs = self.text_decoder.generate(input_ids=input_ids,
173
+ max_length=max_length,
174
+ min_length=min_length,
175
+ do_sample=True,
176
+ top_p=top_p,
177
+ num_return_sequences=1,
178
+ eos_token_id=self.tokenizer.sep_token_id,
179
+ pad_token_id=self.tokenizer.pad_token_id,
180
+ repetition_penalty=1.1,
181
+ **model_kwargs)
182
+ else:
183
+ #beam search
184
+ model_kwargs = {"encoder_hidden_states": output_tagembedding.last_hidden_state, "encoder_attention_mask":None}
185
+ outputs = self.text_decoder.generate(input_ids=input_ids,
186
+ max_length=max_length,
187
+ min_length=min_length,
188
+ num_beams=num_beams,
189
+ eos_token_id=self.tokenizer.sep_token_id,
190
+ pad_token_id=self.tokenizer.pad_token_id,
191
+ repetition_penalty=repetition_penalty,
192
+ **model_kwargs)
193
+
194
+ captions = []
195
+ for output in outputs:
196
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
197
+ captions.append(caption[len(self.prompt):])
198
+ if return_tag_predict == True:
199
+ if sample:
200
+ return captions, tag_input
201
+ else:
202
+ return captions, tag_input[0:int(len(tag_input)/num_beams)]
203
+ return captions
204
+
205
+
206
+ def tag2text_caption(pretrained='',**kwargs):
207
+ model = Tag2Text_Caption(**kwargs)
208
+ if pretrained:
209
+ if kwargs['vit'] == 'swin_b':
210
+ model,msg = load_checkpoint_swinbase(model,pretrained,kwargs)
211
+ else:
212
+ model,msg = load_checkpoint(model,pretrained)
213
+ print('vit:',kwargs['vit'])
214
+ print('msg_v2',msg)
215
+ return model
216
+
217
+
218
+ from typing import List
219
+ def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
220
+ uninitialized_encoder_weights: List[str] = []
221
+ if decoder.__class__ != encoder.__class__:
222
+ logger.info(
223
+ f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
224
+ )
225
+
226
+ def tie_encoder_to_decoder_recursively(
227
+ decoder_pointer: nn.Module,
228
+ encoder_pointer: nn.Module,
229
+ module_name: str,
230
+ uninitialized_encoder_weights: List[str],
231
+ skip_key: str,
232
+ depth=0,
233
+ ):
234
+ assert isinstance(decoder_pointer, nn.Module) and isinstance(
235
+ encoder_pointer, nn.Module
236
+ ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
237
+ if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
238
+ assert hasattr(encoder_pointer, "weight")
239
+ encoder_pointer.weight = decoder_pointer.weight
240
+ if hasattr(decoder_pointer, "bias"):
241
+ assert hasattr(encoder_pointer, "bias")
242
+ encoder_pointer.bias = decoder_pointer.bias
243
+ print(module_name+' is tied')
244
+ return
245
+
246
+ encoder_modules = encoder_pointer._modules
247
+ decoder_modules = decoder_pointer._modules
248
+ if len(decoder_modules) > 0:
249
+ assert (
250
+ len(encoder_modules) > 0
251
+ ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
252
+
253
+ all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
254
+ encoder_layer_pos = 0
255
+ for name, module in decoder_modules.items():
256
+ if name.isdigit():
257
+ encoder_name = str(int(name) + encoder_layer_pos)
258
+ decoder_name = name
259
+ if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
260
+ encoder_modules
261
+ ) != len(decoder_modules):
262
+ # this can happen if the name corresponds to the position in a list module list of layers
263
+ # in this case the decoder has added a cross-attention that the encoder does not have
264
+ # thus skip this step and subtract one layer pos from encoder
265
+ encoder_layer_pos -= 1
266
+ continue
267
+ elif name not in encoder_modules:
268
+ continue
269
+ elif depth > 500:
270
+ raise ValueError(
271
+ "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
272
+ )
273
+ else:
274
+ decoder_name = encoder_name = name
275
+ tie_encoder_to_decoder_recursively(
276
+ decoder_modules[decoder_name],
277
+ encoder_modules[encoder_name],
278
+ module_name + "/" + name,
279
+ uninitialized_encoder_weights,
280
+ skip_key,
281
+ depth=depth + 1,
282
+ )
283
+ all_encoder_weights.remove(module_name + "/" + encoder_name)
284
+
285
+ uninitialized_encoder_weights += list(all_encoder_weights)
286
+
287
+ # tie weights recursively
288
+ tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
289
+
290
+
291
+ class GroupWiseLinear(nn.Module):
292
+ # could be changed to:
293
+ # output = torch.einsum('ijk,zjk->ij', x, self.W)
294
+ # or output = torch.einsum('ijk,jk->ij', x, self.W[0])
295
+ def __init__(self, num_class, hidden_dim, bias=True):
296
+ super().__init__()
297
+ self.num_class = num_class
298
+ self.hidden_dim = hidden_dim
299
+ self.bias = bias
300
+
301
+ self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
302
+ if bias:
303
+ self.b = nn.Parameter(torch.Tensor(1, num_class))
304
+ self.reset_parameters()
305
+
306
+ def reset_parameters(self):
307
+ stdv = 1. / math.sqrt(self.W.size(2))
308
+ for i in range(self.num_class):
309
+ self.W[0][i].data.uniform_(-stdv, stdv)
310
+ if self.bias:
311
+ for i in range(self.num_class):
312
+ self.b[0][i].data.uniform_(-stdv, stdv)
313
+
314
+ def forward(self, x):
315
+ # x: B,K,d
316
+ x = (self.W * x).sum(-1)
317
+ if self.bias:
318
+ x = x + self.b
319
+ return x
320
+
321
+
322
+ def init_tokenizer():
323
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
324
+ tokenizer.add_special_tokens({'bos_token':'[DEC]'})
325
+ tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
326
+ tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
327
+ return tokenizer
328
+
329
+
330
+ def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
331
+
332
+ assert vit in ['base', 'large'], "vit parameter must be base or large"
333
+ if vit=='base':
334
+ vision_width = 768
335
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
336
+ num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
337
+ drop_path_rate=0 or drop_path_rate
338
+ )
339
+ elif vit=='large':
340
+ vision_width = 1024
341
+ visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
342
+ num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
343
+ drop_path_rate=0.1 or drop_path_rate
344
+ )
345
+ return visual_encoder, vision_width
346
+
347
+ def is_url(url_or_filename):
348
+ parsed = urlparse(url_or_filename)
349
+ return parsed.scheme in ("http", "https")
350
+
351
+ def load_checkpoint(model,url_or_filename):
352
+ if is_url(url_or_filename):
353
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
354
+ checkpoint = torch.load(cached_file, map_location='cpu')
355
+ elif os.path.isfile(url_or_filename):
356
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
357
+ else:
358
+ raise RuntimeError('checkpoint url or path is invalid')
359
+
360
+ state_dict = checkpoint['model']
361
+
362
+ state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
363
+ if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
364
+ state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
365
+ model.visual_encoder_m)
366
+ for key in model.state_dict().keys():
367
+ if key in state_dict.keys():
368
+ if state_dict[key].shape!=model.state_dict()[key].shape:
369
+ del state_dict[key]
370
+
371
+ msg = model.load_state_dict(state_dict,strict=False)
372
+ print('load checkpoint from %s'%url_or_filename)
373
+ return model,msg
374
+
375
+
376
+ def load_checkpoint_swinbase(model,url_or_filename,kwargs):
377
+ if kwargs['image_size'] == 224:
378
+ vision_config_path = 'configs/swin/config_swinB_224.json'
379
+ elif kwargs['image_size'] == 384:
380
+ vision_config_path = 'configs/swin/config_swinB_384.json'
381
+ elif kwargs['image_size'] == 480:
382
+ vision_config_path = 'configs/swin/config_swinB_480.json'
383
+ elif kwargs['image_size'] == 576:
384
+ vision_config_path = 'configs/swin/config_swinB_576.json'
385
+ elif kwargs['image_size'] == 608:
386
+ vision_config_path = 'configs/swin/config_swinB_608.json'
387
+ window_size = read_json(vision_config_path)['window_size']
388
+ print('--------------')
389
+ print(url_or_filename)
390
+ print('--------------')
391
+ if is_url(url_or_filename):
392
+ cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
393
+ checkpoint = torch.load(cached_file, map_location='cpu')
394
+ elif os.path.isfile(url_or_filename):
395
+ checkpoint = torch.load(url_or_filename, map_location='cpu')
396
+ else:
397
+ raise RuntimeError('checkpoint url or path is invalid')
398
+
399
+ state_dict = checkpoint['model']
400
+
401
+ for k in list(state_dict.keys()):
402
+ if 'relative_position_bias_table' in k:
403
+ dst_num_pos = (2 * window_size - 1) ** 2
404
+ state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k)
405
+ elif ('relative_position_index' in k) or ('attn_mask' in k):
406
+ del state_dict[k]
407
+
408
+ msg = model.load_state_dict(state_dict,strict=False)
409
+ print('load checkpoint from %s'%url_or_filename)
410
+ return model,msg
411
+
412
+
413
+
414
+
415
+
models/vit.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ * Copyright (c) 2022, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on timm code base
8
+ * https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ '''
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from functools import partial
15
+
16
+ from timm.models.vision_transformer import _cfg, PatchEmbed
17
+ from timm.models.registry import register_model
18
+ from timm.models.layers import trunc_normal_, DropPath
19
+ from timm.models.helpers import named_apply, adapt_input_conv
20
+
21
+ from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
22
+
23
+ class Mlp(nn.Module):
24
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
25
+ """
26
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x):
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
42
+
43
+
44
+ class Attention(nn.Module):
45
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
46
+ super().__init__()
47
+ self.num_heads = num_heads
48
+ head_dim = dim // num_heads
49
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
50
+ self.scale = qk_scale or head_dim ** -0.5
51
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
52
+ self.attn_drop = nn.Dropout(attn_drop)
53
+ self.proj = nn.Linear(dim, dim)
54
+ self.proj_drop = nn.Dropout(proj_drop)
55
+ self.attn_gradients = None
56
+ self.attention_map = None
57
+
58
+ def save_attn_gradients(self, attn_gradients):
59
+ self.attn_gradients = attn_gradients
60
+
61
+ def get_attn_gradients(self):
62
+ return self.attn_gradients
63
+
64
+ def save_attention_map(self, attention_map):
65
+ self.attention_map = attention_map
66
+
67
+ def get_attention_map(self):
68
+ return self.attention_map
69
+
70
+ def forward(self, x, register_hook=False):
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
73
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
74
+
75
+ attn = (q @ k.transpose(-2, -1)) * self.scale
76
+ attn = attn.softmax(dim=-1)
77
+ attn = self.attn_drop(attn)
78
+
79
+ if register_hook:
80
+ self.save_attention_map(attn)
81
+ attn.register_hook(self.save_attn_gradients)
82
+
83
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
+ class Block(nn.Module):
90
+
91
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
92
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False):
93
+ super().__init__()
94
+ self.norm1 = norm_layer(dim)
95
+ self.attn = Attention(
96
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
97
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
98
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
99
+ self.norm2 = norm_layer(dim)
100
+ mlp_hidden_dim = int(dim * mlp_ratio)
101
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
102
+
103
+ if use_grad_checkpointing:
104
+ self.attn = checkpoint_wrapper(self.attn)
105
+ self.mlp = checkpoint_wrapper(self.mlp)
106
+
107
+ def forward(self, x, register_hook=False):
108
+ x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
109
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
110
+ return x
111
+
112
+
113
+ class VisionTransformer(nn.Module):
114
+ """ Vision Transformer
115
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
116
+ https://arxiv.org/abs/2010.11929
117
+ """
118
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
119
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
120
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
121
+ use_grad_checkpointing=False, ckpt_layer=0):
122
+ """
123
+ Args:
124
+ img_size (int, tuple): input image size
125
+ patch_size (int, tuple): patch size
126
+ in_chans (int): number of input channels
127
+ num_classes (int): number of classes for classification head
128
+ embed_dim (int): embedding dimension
129
+ depth (int): depth of transformer
130
+ num_heads (int): number of attention heads
131
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
132
+ qkv_bias (bool): enable bias for qkv if True
133
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
134
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
135
+ drop_rate (float): dropout rate
136
+ attn_drop_rate (float): attention dropout rate
137
+ drop_path_rate (float): stochastic depth rate
138
+ norm_layer: (nn.Module): normalization layer
139
+ """
140
+ super().__init__()
141
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
142
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+
147
+ num_patches = self.patch_embed.num_patches
148
+
149
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
150
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
151
+ self.pos_drop = nn.Dropout(p=drop_rate)
152
+
153
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
154
+ self.blocks = nn.ModuleList([
155
+ Block(
156
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
157
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
158
+ use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer)
159
+ )
160
+ for i in range(depth)])
161
+ self.norm = norm_layer(embed_dim)
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ @torch.jit.ignore
177
+ def no_weight_decay(self):
178
+ return {'pos_embed', 'cls_token'}
179
+
180
+ def forward(self, x, register_blk=-1):
181
+ B = x.shape[0]
182
+ x = self.patch_embed(x)
183
+
184
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
185
+ x = torch.cat((cls_tokens, x), dim=1)
186
+
187
+ x = x + self.pos_embed[:,:x.size(1),:]
188
+ x = self.pos_drop(x)
189
+
190
+ for i,blk in enumerate(self.blocks):
191
+ x = blk(x, register_blk==i)
192
+ x = self.norm(x)
193
+
194
+ return x
195
+
196
+ @torch.jit.ignore()
197
+ def load_pretrained(self, checkpoint_path, prefix=''):
198
+ _load_weights(self, checkpoint_path, prefix)
199
+
200
+
201
+ @torch.no_grad()
202
+ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
203
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
204
+ """
205
+ import numpy as np
206
+
207
+ def _n2p(w, t=True):
208
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
209
+ w = w.flatten()
210
+ if t:
211
+ if w.ndim == 4:
212
+ w = w.transpose([3, 2, 0, 1])
213
+ elif w.ndim == 3:
214
+ w = w.transpose([2, 0, 1])
215
+ elif w.ndim == 2:
216
+ w = w.transpose([1, 0])
217
+ return torch.from_numpy(w)
218
+
219
+ w = np.load(checkpoint_path)
220
+ if not prefix and 'opt/target/embedding/kernel' in w:
221
+ prefix = 'opt/target/'
222
+
223
+ if hasattr(model.patch_embed, 'backbone'):
224
+ # hybrid
225
+ backbone = model.patch_embed.backbone
226
+ stem_only = not hasattr(backbone, 'stem')
227
+ stem = backbone if stem_only else backbone.stem
228
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
229
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
230
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
231
+ if not stem_only:
232
+ for i, stage in enumerate(backbone.stages):
233
+ for j, block in enumerate(stage.blocks):
234
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
235
+ for r in range(3):
236
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
237
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
238
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
239
+ if block.downsample is not None:
240
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
241
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
242
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
243
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
244
+ else:
245
+ embed_conv_w = adapt_input_conv(
246
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
247
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
248
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
249
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
250
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
251
+ if pos_embed_w.shape != model.pos_embed.shape:
252
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
253
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
254
+ model.pos_embed.copy_(pos_embed_w)
255
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
256
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
257
+ # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
258
+ # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
259
+ # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
260
+ # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
261
+ # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
262
+ # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
263
+ for i, block in enumerate(model.blocks.children()):
264
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
265
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
266
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
267
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
268
+ block.attn.qkv.weight.copy_(torch.cat([
269
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
270
+ block.attn.qkv.bias.copy_(torch.cat([
271
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
272
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
273
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
274
+ for r in range(2):
275
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
276
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
277
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
278
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
279
+
280
+
281
+ def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
282
+ # interpolate position embedding
283
+ embedding_size = pos_embed_checkpoint.shape[-1]
284
+ num_patches = visual_encoder.patch_embed.num_patches
285
+ num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
286
+ # height (== width) for the checkpoint position embedding
287
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
288
+ # height (== width) for the new position embedding
289
+ new_size = int(num_patches ** 0.5)
290
+
291
+ if orig_size!=new_size:
292
+ # class_token and dist_token are kept unchanged
293
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
294
+ # only the position tokens are interpolated
295
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
296
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
297
+ pos_tokens = torch.nn.functional.interpolate(
298
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
299
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
300
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
301
+ print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2))
302
+
303
+ return new_pos_embed
304
+ else:
305
+ return pos_embed_checkpoint
requirements.txt CHANGED
@@ -1,7 +1,4 @@
1
  timm==0.4.12
2
- git+https://github.com/huggingface/transformers.git@main
3
  fairscale==0.4.4
4
  pycocoevalcap
5
- torch
6
- torchvision
7
- Pillow
 
1
  timm==0.4.12
2
+ transformers==4.15.0
3
  fairscale==0.4.4
4
  pycocoevalcap
 
 
 
upload.ipynb ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "ecc45cb5-15a0-424b-b97d-d29a73b2e809",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "/opt/conda/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
14
+ " from .autonotebook import tqdm as notebook_tqdm\n"
15
+ ]
16
+ },
17
+ {
18
+ "ename": "ImportError",
19
+ "evalue": "cannot import name 'login' from 'huggingface_hub' (/home/oppoer/.local/lib/python3.7/site-packages/huggingface_hub/__init__.py)",
20
+ "output_type": "error",
21
+ "traceback": [
22
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
23
+ "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
24
+ "\u001b[0;32m/tmp/ipykernel_1707/2376108712.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mhuggingface_hub\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mlogin\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
25
+ "\u001b[0;31mImportError\u001b[0m: cannot import name 'login' from 'huggingface_hub' (/home/oppoer/.local/lib/python3.7/site-packages/huggingface_hub/__init__.py)"
26
+ ]
27
+ }
28
+ ],
29
+ "source": [
30
+ "from huggingface_hub import login"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "1d5da498-403a-4d54-8fd2-981665980977",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": []
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "a2947119-5752-4d4f-99f0-e6d306bcf0ae",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": []
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "id": "bba28b02-af45-4be3-b26e-7f456a48fe95",
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": []
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "9f349128-a60e-4bb2-9a06-1628e42cb659",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": []
64
+ }
65
+ ],
66
+ "metadata": {
67
+ "kernelspec": {
68
+ "display_name": "Python 3 (ipykernel)",
69
+ "language": "python",
70
+ "name": "python3"
71
+ },
72
+ "language_info": {
73
+ "codemirror_mode": {
74
+ "name": "ipython",
75
+ "version": 3
76
+ },
77
+ "file_extension": ".py",
78
+ "mimetype": "text/x-python",
79
+ "name": "python",
80
+ "nbconvert_exporter": "python",
81
+ "pygments_lexer": "ipython3",
82
+ "version": "3.7.12"
83
+ }
84
+ },
85
+ "nbformat": 4,
86
+ "nbformat_minor": 5
87
+ }