diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c297f8131850950f27a124669b859f6813028650 Binary files /dev/null and b/.DS_Store differ diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0e40fe8f57160b43f9ea8e200b1a5d9f91f4aed9 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ + +# Default ignored files +/workspace.xml \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..a63e87f4e1e90c96861648a16a7304d97d3c3f7b --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,12 @@ +Copyright (c) 2022, Salesforce.com, Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README 2.md b/README 2.md new file mode 100644 index 0000000000000000000000000000000000000000..ee6fa83b80516116a11f2a65d1656e708c344da3 --- /dev/null +++ b/README 2.md @@ -0,0 +1,46 @@ +--- +title: PPE_Detection +emoji: 💩 +colorFrom: pink +colorTo: indigo +sdk: gradio +app_file: app.py +pinned: false +license: other +--- + +# Configuration + +`title`: _string_ +Display title for the Space + +`emoji`: _string_ +Space emoji (emoji-only character allowed) + +`colorFrom`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`colorTo`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`sdk`: _string_ +Can be either `gradio`, `streamlit`, or `static` + +`sdk_version` : _string_ +Only applicable for `streamlit` SDK. +See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions. + +`app_file`: _string_ +Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code). +Path is relative to the root of the repository. + +`models`: _List[string]_ +HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space. +Will be parsed automatically from your code if not specified here. + +`datasets`: _List[string]_ +HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space. +Will be parsed automatically from your code if not specified here. + +`pinned`: _boolean_ +Whether the Space stays on top of your list. diff --git a/README.md b/README.md index d747e21b37b5771a5558b2841e5f45942cc81e4c..faf29046687493dd3e0e2a5a8edf887ff9e45ec5 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,46 @@ --- -title: Safeworld_Captioning_Spaces -emoji: 📚 -colorFrom: indigo -colorTo: purple +title: BLIP +emoji: 🦀 +colorFrom: red +colorTo: blue sdk: gradio app_file: app.py pinned: false -license: other +license: bsd-3-clause --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference +# Configuration + +`title`: _string_ +Display title for the Space + +`emoji`: _string_ +Space emoji (emoji-only character allowed) + +`colorFrom`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`colorTo`: _string_ +Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray) + +`sdk`: _string_ +Can be either `gradio`, `streamlit`, or `static` + +`sdk_version` : _string_ +Only applicable for `streamlit` SDK. +See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions. + +`app_file`: _string_ +Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code). +Path is relative to the root of the repository. + +`models`: _List[string]_ +HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space. +Will be parsed automatically from your code if not specified here. + +`datasets`: _List[string]_ +HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space. +Will be parsed automatically from your code if not specified here. + +`pinned`: _boolean_ +Whether the Space stays on top of your list. diff --git a/__pycache__/run_code.cpython-38.pyc b/__pycache__/run_code.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee1af932f4c1eaadc45ccc7a7ca3f49b12971730 Binary files /dev/null and b/__pycache__/run_code.cpython-38.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..73bb9456f0fc6573f6009db58c392d70a51b4665 --- /dev/null +++ b/app.py @@ -0,0 +1,232 @@ +import os + +os.system( + "wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -O starry.jpg") + +from PIL import Image +import requests +import torch +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# MDETR Code +import torchvision.transforms as T +import matplotlib.pyplot as plt +from collections import defaultdict +import torch.nn.functional as F +import numpy as np +from skimage.measure import find_contours + +from matplotlib import patches, lines +from matplotlib.patches import Polygon +import gradio as gr + +torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2014/03/04/15/10/elephants-279505_1280.jpg', + 'elephant.jpg') + +model2, postprocessor = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5', pretrained=True, + return_postprocessor=True) +model2 = model2.cpu() +model2.eval() + +torch.set_grad_enabled(False); +# standard PyTorch mean-std input image normalization +transform = T.Compose([ + T.Resize(800), + T.ToTensor(), + T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +]) + + +# for output bounding box post-processing +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=1) + + +def rescale_bboxes(out_bbox, size): + img_w, img_h = size + b = box_cxcywh_to_xyxy(out_bbox) + b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) + return b + + +# colors for visualization +COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], + [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]] + + +def apply_mask(image, mask, color, alpha=0.5): + """Apply the given mask to the image. + """ + for c in range(3): + image[:, :, c] = np.where(mask == 1, + image[:, :, c] * + (1 - alpha) + alpha * color[c] * 255, + image[:, :, c]) + return image + + +def plot_results(pil_img, scores, boxes, labels, masks=None): + plt.figure(figsize=(16, 10)) + np_image = np.array(pil_img) + ax = plt.gca() + colors = COLORS * 100 + if masks is None: + masks = [None for _ in range(len(scores))] + assert len(scores) == len(boxes) == len(labels) == len(masks) + for s, (xmin, ymin, xmax, ymax), l, mask, c in zip(scores, boxes.tolist(), labels, masks, colors): + ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, + fill=False, color=c, linewidth=3)) + text = f'{l}: {s:0.2f}' + ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8)) + + if mask is None: + continue + np_image = apply_mask(np_image, mask, c) + + padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8) + padded_mask[1:-1, 1:-1] = mask + contours = find_contours(padded_mask, 0.5) + for verts in contours: + # Subtract the padding and flip (y, x) to (x, y) + verts = np.fliplr(verts) - 1 + p = Polygon(verts, facecolor="none", edgecolor=c) + ax.add_patch(p) + + plt.imshow(np_image) + plt.axis('off') + plt.savefig('foo.png', bbox_inches='tight') + return 'foo.png' + + +def add_res(results, ax, color='green'): + # for tt in results.values(): + if True: + bboxes = results['boxes'] + labels = results['labels'] + scores = results['scores'] + # keep = scores >= 0.0 + # bboxes = bboxes[keep].tolist() + # labels = labels[keep].tolist() + # scores = scores[keep].tolist() + # print(torchvision.ops.box_iou(tt['boxes'].cpu().detach(), torch.as_tensor([[xmin, ymin, xmax, ymax]]))) + + colors = ['purple', 'yellow', 'red', 'green', 'orange', 'pink'] + + for i, (b, ll, ss) in enumerate(zip(bboxes, labels, scores)): + ax.add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], fill=False, color=colors[i], linewidth=3)) + cls_name = ll if isinstance(ll, str) else CLASSES[ll] + text = f'{cls_name}: {ss:.2f}' + print(text) + ax.text(b[0], b[1], text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8)) + + +def plot_inference(im, caption, approaches): + choices = {"Worker Helmet Separately": 1, "Worker Helmet Vest": 2, "Workers only": 3} + + # mean-std normalize the input image (batch-size: 1) + img = transform(im).unsqueeze(0).cpu() + + # propagate through the model + memory_cache = model2(img, [caption], encode_and_save=True) + outputs = model2(img, [caption], encode_and_save=False, memory_cache=memory_cache) + + # keep only predictions with 0.7+ confidence + probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu() + keep = (probas > 0.7).cpu() + + # convert boxes from [0; 1] to image scales + bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size) + + # Extract the text spans predicted by each box + positive_tokens = (outputs["pred_logits"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist() + predicted_spans = defaultdict(str) + for tok in positive_tokens: + item, pos = tok + if pos < 255: + span = memory_cache["tokenized"].token_to_chars(0, pos) + predicted_spans[item] += " " + caption[span.start:span.end] + + labels = [predicted_spans[k] for k in sorted(list(predicted_spans.keys()))] + caption = 'Caption: ' + caption + return (sepia_call(caption, im, plot_results(im, probas[keep], bboxes_scaled, labels), choices[approaches])) + + +# BLIP Code + + +from modelsn.blip import blip_decoder + +image_size = 384 +transform = transforms.Compose([ + transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) +]) + +model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' + +model = blip_decoder(pretrained=model_url, image_size=384, vit='base') +model.eval() +model = model.to(device) + +from modelsn.blip_vqa import blip_vqa + +image_size_vq = 480 +transform_vq = transforms.Compose([ + transforms.Resize((image_size_vq, image_size_vq), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) +]) + +model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth' + +model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base') +model_vq.eval() +model_vq = model_vq.to(device) + + +def inference(raw_image, approaches, question): + image = transform(raw_image).unsqueeze(0).to(device) + with torch.no_grad(): + caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5) + + return (plot_inference(raw_image, caption[0], approaches)) + # return 'caption: '+caption[0] + + +# PPE Detection code +import numpy as np +import run_code +import gradio as gr + + +def sepia_call(caption, Input_Image, MDETR_im, Approach): + pil_image = Input_Image + open_cv_image = np.asarray(pil_image) + sepia_img = run_code.run(open_cv_image, Approach) + images = sepia_img['img'] + texts = sepia_img['text'] + + return (caption, MDETR_im, images, texts) + + +inputs = [gr.inputs.Image(type='pil'), + gr.inputs.Radio(choices=["Worker Helmet Separately", "Worker Helmet Vest", "Workers only"], type="value", + default="Worker Helmet Vest", label="Model"), "textbox"] +outputs = [gr.outputs.Textbox(label="Output"), "image", "image", gr.outputs.Textbox(label="Output")] + +title = "BLIP + MDETR + PPE Detection" + +description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation by Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below." + +article = "

BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation | Github Repo

" + +gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, + examples=[['starry.jpg', "Image Captioning", "None"]]).launch(share=True, enable_queue=True, + cache_examples=False) \ No newline at end of file diff --git a/app_run.ipynb b/app_run.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..49b0a64cb2d3dbba39564fb179c607661a75c78f --- /dev/null +++ b/app_run.ipynb @@ -0,0 +1,400 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "15468c81", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "--2022-02-15 18:26:17-- https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg\n", + "Resolving upload.wikimedia.org (upload.wikimedia.org)... 91.198.174.208\n", + "Connecting to upload.wikimedia.org (upload.wikimedia.org)|91.198.174.208|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 1388211 (1.3M) [image/jpeg]\n", + "Saving to: ‘starry.jpg’\n", + "\n", + " 0K .......... .......... .......... .......... .......... 3% 776K 2s\n", + " 50K .......... .......... .......... .......... .......... 7% 877K 2s\n", + " 100K .......... .......... .......... .......... .......... 11% 2.93M 1s\n", + " 150K .......... .......... .......... .......... .......... 14% 2.28M 1s\n", + " 200K .......... .......... .......... .......... .......... 18% 4.04M 1s\n", + " 250K .......... .......... .......... .......... .......... 22% 5.46M 1s\n", + " 300K .......... .......... .......... .......... .......... 25% 6.40M 1s\n", + " 350K .......... .......... .......... .......... .......... 29% 2.41M 0s\n", + " 400K .......... .......... .......... .......... .......... 33% 3.18M 0s\n", + " 450K .......... .......... .......... .......... .......... 36% 3.03M 0s\n", + " 500K .......... .......... .......... .......... .......... 40% 8.30M 0s\n", + " 550K .......... .......... .......... .......... .......... 44% 3.31M 0s\n", + " 600K .......... .......... .......... .......... .......... 47% 3.10M 0s\n", + " 650K .......... .......... .......... .......... .......... 51% 12.3M 0s\n", + " 700K .......... .......... .......... .......... .......... 55% 4.20M 0s\n", + " 750K .......... .......... .......... .......... .......... 59% 1.93M 0s\n", + " 800K .......... .......... .......... .......... .......... 62% 6.28M 0s\n", + " 850K .......... .......... .......... .......... .......... 66% 3.09M 0s\n", + " 900K .......... .......... .......... .......... .......... 70% 22.7M 0s\n", + " 950K .......... .......... .......... .......... .......... 73% 4.43M 0s\n", + " 1000K .......... .......... .......... .......... .......... 77% 4.16M 0s\n", + " 1050K .......... .......... .......... .......... .......... 81% 2.29M 0s\n", + " 1100K .......... .......... .......... .......... .......... 84% 1.81M 0s\n", + " 1150K .......... .......... .......... .......... .......... 88% 6.20M 0s\n", + " 1200K .......... .......... .......... .......... .......... 92% 2.03M 0s\n", + " 1250K .......... .......... .......... .......... .......... 95% 23.5M 0s\n", + " 1300K .......... .......... .......... .......... .......... 99% 5.04M 0s\n", + " 1350K ..... 100% 9.95M=0.5s\n", + "\n", + "2022-02-15 18:26:17 (2.89 MB/s) - ‘starry.jpg’ saved [1388211/1388211]\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "02b7655f0b2b404b952b7c152a3a1661", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0.00/262k [00:00\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(,\n", + " 'http://127.0.0.1:7862/',\n", + " 'https://13389.gradio.app')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2022-02-15 18:27:19.011924: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", + "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + } + ], + "source": [ + "import os\n", + "os.system(\"wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -O starry.jpg\")\n", + "\n", + "from PIL import Image\n", + "import requests\n", + "import torch\n", + "from torchvision import transforms\n", + "from torchvision.transforms.functional import InterpolationMode\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "\n", + "\n", + " \n", + "#MDETR Code \n", + "import torchvision.transforms as T\n", + "import matplotlib.pyplot as plt\n", + "from collections import defaultdict\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "from skimage.measure import find_contours\n", + "\n", + "from matplotlib import patches, lines\n", + "from matplotlib.patches import Polygon\n", + "import gradio as gr\n", + "\n", + "torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2014/03/04/15/10/elephants-279505_1280.jpg', 'elephant.jpg')\n", + "\n", + "\n", + "model2, postprocessor = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5', pretrained=True, return_postprocessor=True)\n", + "model2 = model2.cpu()\n", + "model2.eval()\n", + "\n", + "\n", + "\n", + "\n", + "torch.set_grad_enabled(False);\n", + "# standard PyTorch mean-std input image normalization\n", + "transform = T.Compose([\n", + " T.Resize(800),\n", + " T.ToTensor(),\n", + " T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", + "])\n", + "\n", + "# for output bounding box post-processing\n", + "def box_cxcywh_to_xyxy(x):\n", + " x_c, y_c, w, h = x.unbind(1)\n", + " b = [(x_c - 0.5 * w), (y_c - 0.5 * h),\n", + " (x_c + 0.5 * w), (y_c + 0.5 * h)]\n", + " return torch.stack(b, dim=1)\n", + "\n", + "def rescale_bboxes(out_bbox, size):\n", + " img_w, img_h = size\n", + " b = box_cxcywh_to_xyxy(out_bbox)\n", + " b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)\n", + " return b\n", + "# colors for visualization\n", + "COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],\n", + " [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]\n", + "\n", + "def apply_mask(image, mask, color, alpha=0.5):\n", + " \"\"\"Apply the given mask to the image.\n", + " \"\"\"\n", + " for c in range(3):\n", + " image[:, :, c] = np.where(mask == 1,\n", + " image[:, :, c] *\n", + " (1 - alpha) + alpha * color[c] * 255,\n", + " image[:, :, c])\n", + " return image\n", + "\n", + "def plot_results(pil_img, scores, boxes, labels, masks=None):\n", + " plt.figure(figsize=(16,10))\n", + " np_image = np.array(pil_img)\n", + " ax = plt.gca()\n", + " colors = COLORS * 100\n", + " if masks is None:\n", + " masks = [None for _ in range(len(scores))]\n", + " assert len(scores) == len(boxes) == len(labels) == len(masks)\n", + " for s, (xmin, ymin, xmax, ymax), l, mask, c in zip(scores, boxes.tolist(), labels, masks, colors):\n", + " ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,\n", + " fill=False, color=c, linewidth=3))\n", + " text = f'{l}: {s:0.2f}'\n", + " ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))\n", + "\n", + " if mask is None:\n", + " continue\n", + " np_image = apply_mask(np_image, mask, c)\n", + "\n", + " padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)\n", + " padded_mask[1:-1, 1:-1] = mask\n", + " contours = find_contours(padded_mask, 0.5)\n", + " for verts in contours:\n", + " # Subtract the padding and flip (y, x) to (x, y)\n", + " verts = np.fliplr(verts) - 1\n", + " p = Polygon(verts, facecolor=\"none\", edgecolor=c)\n", + " ax.add_patch(p)\n", + "\n", + "\n", + " plt.imshow(np_image)\n", + " plt.axis('off')\n", + " plt.savefig('foo.png',bbox_inches='tight')\n", + " return 'foo.png'\n", + "\n", + "\n", + "def add_res(results, ax, color='green'):\n", + " #for tt in results.values():\n", + " if True:\n", + " bboxes = results['boxes']\n", + " labels = results['labels']\n", + " scores = results['scores']\n", + " #keep = scores >= 0.0\n", + " #bboxes = bboxes[keep].tolist()\n", + " #labels = labels[keep].tolist()\n", + " #scores = scores[keep].tolist()\n", + " #print(torchvision.ops.box_iou(tt['boxes'].cpu().detach(), torch.as_tensor([[xmin, ymin, xmax, ymax]])))\n", + " \n", + " colors = ['purple', 'yellow', 'red', 'green', 'orange', 'pink']\n", + " \n", + " for i, (b, ll, ss) in enumerate(zip(bboxes, labels, scores)):\n", + " ax.add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], fill=False, color=colors[i], linewidth=3))\n", + " cls_name = ll if isinstance(ll,str) else CLASSES[ll]\n", + " text = f'{cls_name}: {ss:.2f}'\n", + " print(text)\n", + " ax.text(b[0], b[1], text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))\n", + "\n", + "\n", + "def plot_inference(im, caption, approaches):\n", + " \n", + " choices = {\"Worker Helmet Separately\" : 1,\"Worker Helmet Vest\":2, \"Workers only\":3}\n", + " \n", + " \n", + "# mean-std normalize the input image (batch-size: 1)\n", + " img = transform(im).unsqueeze(0).cpu()\n", + "\n", + " # propagate through the model\n", + " memory_cache = model2(img, [caption], encode_and_save=True)\n", + " outputs = model2(img, [caption], encode_and_save=False, memory_cache=memory_cache)\n", + "\n", + " # keep only predictions with 0.7+ confidence\n", + " probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu()\n", + " keep = (probas > 0.7).cpu()\n", + "\n", + " # convert boxes from [0; 1] to image scales\n", + " bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size)\n", + "\n", + " # Extract the text spans predicted by each box\n", + " positive_tokens = (outputs[\"pred_logits\"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist()\n", + " predicted_spans = defaultdict(str)\n", + " for tok in positive_tokens:\n", + " item, pos = tok\n", + " if pos < 255:\n", + " span = memory_cache[\"tokenized\"].token_to_chars(0, pos)\n", + " predicted_spans [item] += \" \" + caption[span.start:span.end]\n", + "\n", + " labels = [predicted_spans [k] for k in sorted(list(predicted_spans .keys()))]\n", + " caption = 'Caption: '+ caption\n", + " return (sepia_call(caption, im, plot_results(im, probas[keep], bboxes_scaled, labels), choices[approaches]))\n", + " \n", + "\n", + "\n", + " \n", + "#BLIP Code\n", + "\n", + "\n", + "from modelsn.blip import blip_decoder\n", + "\n", + "image_size = 384\n", + "transform = transforms.Compose([\n", + " transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n", + " ]) \n", + "\n", + "model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n", + " \n", + "model = blip_decoder(pretrained=model_url, image_size=384, vit='base')\n", + "model.eval()\n", + "model = model.to(device)\n", + "\n", + "\n", + "from modelsn.blip_vqa import blip_vqa\n", + "\n", + "image_size_vq = 480\n", + "transform_vq = transforms.Compose([\n", + " transforms.Resize((image_size_vq,image_size_vq),interpolation=InterpolationMode.BICUBIC),\n", + " transforms.ToTensor(),\n", + " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n", + " ]) \n", + "\n", + "model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'\n", + " \n", + "model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')\n", + "model_vq.eval()\n", + "model_vq = model_vq.to(device)\n", + "\n", + "\n", + "\n", + "def inference(raw_image, approaches, question):\n", + " \n", + "\n", + " image = transform(raw_image).unsqueeze(0).to(device) \n", + " with torch.no_grad():\n", + " caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)\n", + "\n", + " return (plot_inference(raw_image, caption[0], approaches))\n", + " #return 'caption: '+caption[0]\n", + "\n", + " \n", + "\n", + " \n", + "#PPE Detection code\n", + "import numpy as np\n", + "import run_code\n", + "import gradio as gr\n", + " \n", + "\n", + "def sepia_call(caption, Input_Image, MDETR_im, Approach):\n", + " pil_image = Input_Image\n", + " open_cv_image = np.asarray(pil_image)\n", + " sepia_img = run_code.run(open_cv_image, Approach)\n", + " images = sepia_img['img']\n", + " texts= sepia_img['text']\n", + "\n", + " return (caption, MDETR_im, images, texts)\n", + "\n", + "\n", + "inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=[\"Worker Helmet Separately\",\"Worker Helmet Vest\", \"Workers only\"], type=\"value\", default=\"Worker Helmet Vest\", label=\"Model\"),\"textbox\"]\n", + "outputs = [gr.outputs.Textbox(label=\"Output\"), \"image\", \"image\", gr.outputs.Textbox(label=\"Output\")]\n", + "\n", + "\n", + "title = \"BLIP + MDETR + PPE Detection\"\n", + "\n", + "description = \"Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation by Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.\"\n", + "\n", + "article = \"

BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation | Github Repo

\"\n", + "\n", + "\n", + "gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['starry.jpg',\"Image Captioning\",\"None\"]]).launch(share=True,enable_queue=True,cache_examples=False)" + ] + }, + { + "cell_type": "raw", + "id": "b2729aa9", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/configs/caption_coco.yaml b/configs/caption_coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b398665c6a047a0af48eb2c276b31653d548018b --- /dev/null +++ b/configs/caption_coco.yaml @@ -0,0 +1,33 @@ +image_root: '/export/share/datasets/vision/coco/images/' +ann_root: 'annotation' +coco_gt_root: 'annotation/coco_gt' + +# set pretrained as a file path or an url +pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' + +# size of vit model; base or large +vit: 'base' +vit_grad_ckpt: False +vit_ckpt_layer: 0 +batch_size: 32 +init_lr: 1e-5 + +# vit: 'large' +# vit_grad_ckpt: True +# vit_ckpt_layer: 5 +# batch_size: 16 +# init_lr: 2e-6 + +image_size: 384 + +# generation configs +max_length: 20 +min_length: 5 +num_beams: 3 +prompt: 'a picture of ' + +# optimizer +weight_decay: 0.05 +min_lr: 0 +max_epoch: 5 + diff --git a/configs/med_config.json b/configs/med_config.json new file mode 100644 index 0000000000000000000000000000000000000000..0ffad0a6f3c2f9f11b8faa84529d9860bb70327a --- /dev/null +++ b/configs/med_config.json @@ -0,0 +1,21 @@ +{ + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 12, + "num_hidden_layers": 12, + "pad_token_id": 0, + "type_vocab_size": 2, + "vocab_size": 30524, + "encoder_width": 768, + "add_cross_attention": true +} diff --git a/configs/nlvr.yaml b/configs/nlvr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d1122aadb1a776bd347068233096b0c984f648b --- /dev/null +++ b/configs/nlvr.yaml @@ -0,0 +1,21 @@ +image_root: '/export/share/datasets/vision/NLVR2/' +ann_root: 'annotation' + +# set pretrained as a file path or an url +pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth' + +#size of vit model; base or large +vit: 'base' +batch_size_train: 16 +batch_size_test: 64 +vit_grad_ckpt: False +vit_ckpt_layer: 0 +max_epoch: 15 + +image_size: 384 + +# optimizer +weight_decay: 0.05 +init_lr: 3e-5 +min_lr: 0 + diff --git a/configs/nocaps.yaml b/configs/nocaps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27bb115376bf45bc153bf6097b732da53ca77357 --- /dev/null +++ b/configs/nocaps.yaml @@ -0,0 +1,15 @@ +image_root: '/export/share/datasets/vision/nocaps/' +ann_root: 'annotation' + +# set pretrained as a file path or an url +pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth' + +vit: 'base' +batch_size: 32 + +image_size: 384 + +max_length: 20 +min_length: 5 +num_beams: 3 +prompt: 'a picture of ' \ No newline at end of file diff --git a/configs/pretrain.yaml b/configs/pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..02355ee0228932803c661616485bf315e862b826 --- /dev/null +++ b/configs/pretrain.yaml @@ -0,0 +1,27 @@ +train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json', + '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json', + ] +laion_path: '' + +# size of vit model; base or large +vit: 'base' +vit_grad_ckpt: False +vit_ckpt_layer: 0 + +image_size: 224 +batch_size: 75 + +queue_size: 57600 +alpha: 0.4 + +# optimizer +weight_decay: 0.05 +init_lr: 3e-4 +min_lr: 1e-6 +warmup_lr: 1e-6 +lr_decay_rate: 0.9 +max_epoch: 20 +warmup_steps: 3000 + + + diff --git a/configs/retrieval_coco.yaml b/configs/retrieval_coco.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a8569e9b67112fe3605ac25e4fdc0231f7975378 --- /dev/null +++ b/configs/retrieval_coco.yaml @@ -0,0 +1,34 @@ +image_root: '/export/share/datasets/vision/coco/images/' +ann_root: 'annotation' +dataset: 'coco' + +# set pretrained as a file path or an url +pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth' + +# size of vit model; base or large + +vit: 'base' +batch_size_train: 32 +batch_size_test: 64 +vit_grad_ckpt: True +vit_ckpt_layer: 4 +init_lr: 1e-5 + +# vit: 'large' +# batch_size_train: 16 +# batch_size_test: 32 +# vit_grad_ckpt: True +# vit_ckpt_layer: 12 +# init_lr: 5e-6 + +image_size: 384 +queue_size: 57600 +alpha: 0.4 +k_test: 256 +negative_all_rank: True + +# optimizer +weight_decay: 0.05 +min_lr: 0 +max_epoch: 6 + diff --git a/configs/retrieval_flickr.yaml b/configs/retrieval_flickr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d75ea4eed87c9a001523c5e5914998c5e737594d --- /dev/null +++ b/configs/retrieval_flickr.yaml @@ -0,0 +1,34 @@ +image_root: '/export/share/datasets/vision/flickr30k/' +ann_root: 'annotation' +dataset: 'flickr' + +# set pretrained as a file path or an url +pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth' + +# size of vit model; base or large + +vit: 'base' +batch_size_train: 32 +batch_size_test: 64 +vit_grad_ckpt: True +vit_ckpt_layer: 4 +init_lr: 1e-5 + +# vit: 'large' +# batch_size_train: 16 +# batch_size_test: 32 +# vit_grad_ckpt: True +# vit_ckpt_layer: 10 +# init_lr: 5e-6 + +image_size: 384 +queue_size: 57600 +alpha: 0.4 +k_test: 128 +negative_all_rank: False + +# optimizer +weight_decay: 0.05 +min_lr: 0 +max_epoch: 6 + diff --git a/configs/vqa.yaml b/configs/vqa.yaml new file mode 100644 index 0000000000000000000000000000000000000000..118f3968c7837010b7394a17f5cf2c6b62d5bc11 --- /dev/null +++ b/configs/vqa.yaml @@ -0,0 +1,25 @@ +vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/ +vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/ +train_files: ['vqa_train','vqa_val','vg_qa'] +ann_root: 'annotation' + +# set pretrained as a file path or an url +pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth' + +# size of vit model; base or large +vit: 'base' +batch_size_train: 16 +batch_size_test: 32 +vit_grad_ckpt: False +vit_ckpt_layer: 0 +init_lr: 2e-5 + +image_size: 480 + +k_test: 128 +inference: 'rank' + +# optimizer +weight_decay: 0.05 +min_lr: 0 +max_epoch: 10 \ No newline at end of file diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0be209acf415855ea6ef753efedf903b5decb6b9 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,101 @@ +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +from torchvision.transforms.functional import InterpolationMode + +from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval +from data.nocaps_dataset import nocaps_eval +from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval +from data.vqa_dataset import vqa_dataset +from data.nlvr_dataset import nlvr_dataset +from data.pretrain_dataset import pretrain_dataset +from transform.randaugment import RandomAugment + +def create_dataset(dataset, config, min_scale=0.5): + + normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + + transform_train = transforms.Compose([ + transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), + transforms.RandomHorizontalFlip(), + RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', + 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), + transforms.ToTensor(), + normalize, + ]) + transform_test = transforms.Compose([ + transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + normalize, + ]) + + if dataset=='pretrain': + dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train) + return dataset + + elif dataset=='caption_coco': + train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt']) + val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val') + test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test') + return train_dataset, val_dataset, test_dataset + + elif dataset=='nocaps': + val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val') + test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test') + return val_dataset, test_dataset + + elif dataset=='retrieval_coco': + train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root']) + val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') + test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') + return train_dataset, val_dataset, test_dataset + + elif dataset=='retrieval_flickr': + train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root']) + val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val') + test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test') + return train_dataset, val_dataset, test_dataset + + elif dataset=='vqa': + train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'], + train_files = config['train_files'], split='train') + test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test') + return train_dataset, test_dataset + + elif dataset=='nlvr': + train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train') + val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val') + test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test') + return train_dataset, val_dataset, test_dataset + + +def create_sampler(datasets, shuffles, num_tasks, global_rank): + samplers = [] + for dataset,shuffle in zip(datasets,shuffles): + sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) + samplers.append(sampler) + return samplers + + +def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): + loaders = [] + for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): + if is_train: + shuffle = (sampler is None) + drop_last = True + else: + shuffle = False + drop_last = False + loader = DataLoader( + dataset, + batch_size=bs, + num_workers=n_worker, + pin_memory=True, + sampler=sampler, + shuffle=shuffle, + collate_fn=collate_fn, + drop_last=drop_last, + ) + loaders.append(loader) + return loaders + diff --git a/data/coco_karpathy_dataset.py b/data/coco_karpathy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a34d29205f42aa09695b160ac9c91958ba041bb3 --- /dev/null +++ b/data/coco_karpathy_dataset.py @@ -0,0 +1,126 @@ +import os +import json + +from torch.utils.data import Dataset +from torchvision.datasets.utils import download_url + +from PIL import Image + +from data.utils import pre_caption + +class coco_karpathy_train(Dataset): + def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): + ''' + image_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + ''' + url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json' + filename = 'coco_karpathy_train.json' + + download_url(url,ann_root) + + self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) + self.transform = transform + self.image_root = image_root + self.max_words = max_words + self.prompt = prompt + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann['image_id'] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.image_root,ann['image']) + image = Image.open(image_path).convert('RGB') + image = self.transform(image) + + caption = self.prompt+pre_caption(ann['caption'], self.max_words) + + return image, caption, self.img_ids[ann['image_id']] + + +class coco_karpathy_caption_eval(Dataset): + def __init__(self, transform, image_root, ann_root, split): + ''' + image_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + ''' + urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', + 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} + filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} + + download_url(urls[split],ann_root) + + self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) + self.transform = transform + self.image_root = image_root + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.image_root,ann['image']) + image = Image.open(image_path).convert('RGB') + image = self.transform(image) + + img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1] + + return image, int(img_id) + + +class coco_karpathy_retrieval_eval(Dataset): + def __init__(self, transform, image_root, ann_root, split, max_words=30): + ''' + image_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + split (string): val or test + ''' + urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json', + 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'} + filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'} + + download_url(urls[split],ann_root) + + self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) + self.transform = transform + self.image_root = image_root + + self.text = [] + self.image = [] + self.txt2img = {} + self.img2txt = {} + + txt_id = 0 + for img_id, ann in enumerate(self.annotation): + self.image.append(ann['image']) + self.img2txt[img_id] = [] + for i, caption in enumerate(ann['caption']): + self.text.append(pre_caption(caption,max_words)) + self.img2txt[img_id].append(txt_id) + self.txt2img[txt_id] = img_id + txt_id += 1 + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, index): + + image_path = os.path.join(self.image_root, self.annotation[index]['image']) + image = Image.open(image_path).convert('RGB') + image = self.transform(image) + + return image, index \ No newline at end of file diff --git a/data/flickr30k_dataset.py b/data/flickr30k_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..018ab387014ddaf554c4d3184cfc0e2ba8b2d487 --- /dev/null +++ b/data/flickr30k_dataset.py @@ -0,0 +1,93 @@ +import os +import json + +from torch.utils.data import Dataset +from torchvision.datasets.utils import download_url + +from PIL import Image + +from data.utils import pre_caption + +class flickr30k_train(Dataset): + def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''): + ''' + image_root (string): Root directory of images (e.g. flickr30k/) + ann_root (string): directory to store the annotation file + ''' + url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json' + filename = 'flickr30k_train.json' + + download_url(url,ann_root) + + self.annotation = json.load(open(os.path.join(ann_root,filename),'r')) + self.transform = transform + self.image_root = image_root + self.max_words = max_words + self.prompt = prompt + + self.img_ids = {} + n = 0 + for ann in self.annotation: + img_id = ann['image_id'] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.image_root,ann['image']) + image = Image.open(image_path).convert('RGB') + image = self.transform(image) + + caption = self.prompt+pre_caption(ann['caption'], self.max_words) + + return image, caption, self.img_ids[ann['image_id']] + + +class flickr30k_retrieval_eval(Dataset): + def __init__(self, transform, image_root, ann_root, split, max_words=30): + ''' + image_root (string): Root directory of images (e.g. flickr30k/) + ann_root (string): directory to store the annotation file + split (string): val or test + ''' + urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json', + 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'} + filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'} + + download_url(urls[split],ann_root) + + self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) + self.transform = transform + self.image_root = image_root + + self.text = [] + self.image = [] + self.txt2img = {} + self.img2txt = {} + + txt_id = 0 + for img_id, ann in enumerate(self.annotation): + self.image.append(ann['image']) + self.img2txt[img_id] = [] + for i, caption in enumerate(ann['caption']): + self.text.append(pre_caption(caption,max_words)) + self.img2txt[img_id].append(txt_id) + self.txt2img[txt_id] = img_id + txt_id += 1 + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, index): + + image_path = os.path.join(self.image_root, self.annotation[index]['image']) + image = Image.open(image_path).convert('RGB') + image = self.transform(image) + + return image, index \ No newline at end of file diff --git a/data/nlvr_dataset.py b/data/nlvr_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a8d6b2d7cd8d3260bd279c7dca80de53bacc691a --- /dev/null +++ b/data/nlvr_dataset.py @@ -0,0 +1,78 @@ +import os +import json +import random + +from torch.utils.data import Dataset +from torchvision.datasets.utils import download_url + +from PIL import Image + +from data.utils import pre_caption + +class nlvr_dataset(Dataset): + def __init__(self, transform, image_root, ann_root, split): + ''' + image_root (string): Root directory of images + ann_root (string): directory to store the annotation file + split (string): train, val or test + ''' + urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json', + 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json', + 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'} + filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'} + + download_url(urls[split],ann_root) + self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) + + self.transform = transform + self.image_root = image_root + + + def __len__(self): + return len(self.annotation) + + + def __getitem__(self, index): + + ann = self.annotation[index] + + image0_path = os.path.join(self.image_root,ann['images'][0]) + image0 = Image.open(image0_path).convert('RGB') + image0 = self.transform(image0) + + image1_path = os.path.join(self.image_root,ann['images'][1]) + image1 = Image.open(image1_path).convert('RGB') + image1 = self.transform(image1) + + sentence = pre_caption(ann['sentence'], 40) + + if ann['label']=='True': + label = 1 + else: + label = 0 + + words = sentence.split(' ') + + if 'left' not in words and 'right' not in words: + if random.random()<0.5: + return image0, image1, sentence, label + else: + return image1, image0, sentence, label + else: + if random.random()<0.5: + return image0, image1, sentence, label + else: + new_words = [] + for word in words: + if word=='left': + new_words.append('right') + elif word=='right': + new_words.append('left') + else: + new_words.append(word) + + sentence = ' '.join(new_words) + return image1, image0, sentence, label + + + \ No newline at end of file diff --git a/data/nocaps_dataset.py b/data/nocaps_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ba0bed06d8af3dbaccf18a56e725f101e585503e --- /dev/null +++ b/data/nocaps_dataset.py @@ -0,0 +1,32 @@ +import os +import json + +from torch.utils.data import Dataset +from torchvision.datasets.utils import download_url + +from PIL import Image + +class nocaps_eval(Dataset): + def __init__(self, transform, image_root, ann_root, split): + urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json', + 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'} + filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'} + + download_url(urls[split],ann_root) + + self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r')) + self.transform = transform + self.image_root = image_root + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image_path = os.path.join(self.image_root,ann['image']) + image = Image.open(image_path).convert('RGB') + image = self.transform(image) + + return image, int(ann['img_id']) \ No newline at end of file diff --git a/data/pretrain_dataset.py b/data/pretrain_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..703d543ab5267fdc6fe2b7c84ef6a631d8af90ad --- /dev/null +++ b/data/pretrain_dataset.py @@ -0,0 +1,59 @@ +import json +import os +import random + +from torch.utils.data import Dataset + +from PIL import Image +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True +Image.MAX_IMAGE_PIXELS = None + +from data.utils import pre_caption +import os,glob + +class pretrain_dataset(Dataset): + def __init__(self, ann_file, laion_path, transform): + + self.ann_pretrain = [] + for f in ann_file: + print('loading '+f) + ann = json.load(open(f,'r')) + self.ann_pretrain += ann + + self.laion_path = laion_path + if self.laion_path: + self.laion_files = glob.glob(os.path.join(laion_path,'*.json')) + + print('loading '+self.laion_files[0]) + with open(self.laion_files[0],'r') as f: + self.ann_laion = json.load(f) + + self.annotation = self.ann_pretrain + self.ann_laion + else: + self.annotation = self.ann_pretrain + + self.transform = transform + + + def reload_laion(self, epoch): + n = epoch%len(self.laion_files) + print('loading '+self.laion_files[n]) + with open(self.laion_files[n],'r') as f: + self.ann_laion = json.load(f) + + self.annotation = self.ann_pretrain + self.ann_laion + + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, index): + + ann = self.annotation[index] + + image = Image.open(ann['image']).convert('RGB') + image = self.transform(image) + caption = pre_caption(ann['caption'],30) + + return image, caption \ No newline at end of file diff --git a/data/utils.py b/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..628894844becd462d444584b8b2b01a84ee4b8f7 --- /dev/null +++ b/data/utils.py @@ -0,0 +1,112 @@ +import re +import json +import os + +import torch +import torch.distributed as dist + +import utils + +def pre_caption(caption,max_words=50): + caption = re.sub( + r"([.!\"()*#:;~])", + ' ', + caption.lower(), + ) + caption = re.sub( + r"\s{2,}", + ' ', + caption, + ) + caption = caption.rstrip('\n') + caption = caption.strip(' ') + + #truncate caption + caption_words = caption.split(' ') + if len(caption_words)>max_words: + caption = ' '.join(caption_words[:max_words]) + + return caption + +def pre_question(question,max_ques_words=50): + question = re.sub( + r"([.!\"()*#:;~])", + '', + question.lower(), + ) + question = question.rstrip(' ') + + #truncate question + question_words = question.split(' ') + if len(question_words)>max_ques_words: + question = ' '.join(question_words[:max_ques_words]) + + return question + + +def save_result(result, result_dir, filename, remove_duplicate=''): + result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) + final_result_file = os.path.join(result_dir, '%s.json'%filename) + + json.dump(result,open(result_file,'w')) + + dist.barrier() + + if utils.is_main_process(): + # combine results from all processes + result = [] + + for rank in range(utils.get_world_size()): + result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) + res = json.load(open(result_file,'r')) + result += res + + if remove_duplicate: + result_new = [] + id_list = [] + for res in result: + if res[remove_duplicate] not in id_list: + id_list.append(res[remove_duplicate]) + result_new.append(res) + result = result_new + + json.dump(result,open(final_result_file,'w')) + print('result file saved to %s'%final_result_file) + + return final_result_file + + + +from pycocotools.coco import COCO +from pycocoevalcap.eval import COCOEvalCap +from torchvision.datasets.utils import download_url + +def coco_caption_eval(coco_gt_root, results_file, split): + urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', + 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} + filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'} + + download_url(urls[split],coco_gt_root) + annotation_file = os.path.join(coco_gt_root,filenames[split]) + + # create coco object and coco_result object + coco = COCO(annotation_file) + coco_result = coco.loadRes(results_file) + + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + + # evaluate on a subset of images by setting + # coco_eval.params['image_id'] = coco_result.getImgIds() + # please remove this line when evaluating the full validation set + # coco_eval.params['image_id'] = coco_result.getImgIds() + + # evaluate results + # SPICE will take a few minutes the first time, but speeds up due to caching + coco_eval.evaluate() + + # print output evaluation scores + for metric, score in coco_eval.eval.items(): + print(f'{metric}: {score:.3f}') + + return coco_eval \ No newline at end of file diff --git a/data/vqa_dataset.py b/data/vqa_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..92ec1df429b3910316ddd554bfea01c6e7922cae --- /dev/null +++ b/data/vqa_dataset.py @@ -0,0 +1,88 @@ +import os +import json +import random +from PIL import Image + +import torch +from torch.utils.data import Dataset +from data.utils import pre_question + +from torchvision.datasets.utils import download_url + +class vqa_dataset(Dataset): + def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"): + self.split = split + + self.transform = transform + self.vqa_root = vqa_root + self.vg_root = vg_root + + if split=='train': + urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json', + 'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json', + 'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'} + + self.annotation = [] + for f in train_files: + download_url(urls[f],ann_root) + self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r')) + else: + download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root) + self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r')) + + download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root) + self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r')) + + + def __len__(self): + return len(self.annotation) + + def __getitem__(self, index): + + ann = self.annotation[index] + + if ann['dataset']=='vqa': + image_path = os.path.join(self.vqa_root,ann['image']) + elif ann['dataset']=='vg': + image_path = os.path.join(self.vg_root,ann['image']) + + image = Image.open(image_path).convert('RGB') + image = self.transform(image) + + if self.split == 'test': + question = pre_question(ann['question']) + question_id = ann['question_id'] + return image, question, question_id + + + elif self.split=='train': + + question = pre_question(ann['question']) + + if ann['dataset']=='vqa': + answer_weight = {} + for answer in ann['answer']: + if answer in answer_weight.keys(): + answer_weight[answer] += 1/len(ann['answer']) + else: + answer_weight[answer] = 1/len(ann['answer']) + + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + + elif ann['dataset']=='vg': + answers = [ann['answer']] + weights = [0.2] + + return image, question, answers, weights + + +def vqa_collate_fn(batch): + image_list, question_list, answer_list, weight_list, n = [], [], [], [], [] + for image, question, answer, weights in batch: + image_list.append(image) + question_list.append(question) + weight_list += weights + answer_list += answer + n.append(len(answer)) + return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n \ No newline at end of file diff --git a/elephant.jpg b/elephant.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8757046aed85d3e784db5f1f8f4c74e8bc906abd Binary files /dev/null and b/elephant.jpg differ diff --git a/eval_nocaps.py b/eval_nocaps.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbb09a8cc7771605c013583d721aa95d9413b42 --- /dev/null +++ b/eval_nocaps.py @@ -0,0 +1,118 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import argparse +import os +import ruamel_yaml as yaml +import numpy as np +import random +import time +import datetime +import json +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +import torch.distributed as dist +from torch.utils.data import DataLoader + +from models.blip import blip_decoder +import utils +from data import create_dataset, create_sampler, create_loader +from data.utils import save_result + +@torch.no_grad() +def evaluate(model, data_loader, device, config): + # evaluate + model.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Evaluation:' + print_freq = 10 + + result = [] + for image, image_id in metric_logger.log_every(data_loader, print_freq, header): + + image = image.to(device) + + captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], + min_length=config['min_length'], repetition_penalty=1.1) + + for caption, img_id in zip(captions, image_id): + result.append({"image_id": img_id.item(), "caption": caption}) + + return result + + +def main(args, config): + utils.init_distributed_mode(args) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = True + + #### Dataset #### + print("Creating captioning dataset") + val_dataset, test_dataset = create_dataset('nocaps', config) + + if args.distributed: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank) + else: + samplers = [None,None] + + val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers, + batch_size=[config['batch_size']]*2,num_workers=[4,4], + is_trains=[False, False], collate_fns=[None,None]) + + #### Model #### + print("Creating model") + model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], + prompt=config['prompt']) + + model = model.to(device) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + val_result = evaluate(model_without_ddp, val_loader, device, config) + val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id') + test_result = evaluate(model_without_ddp, test_loader, device, config) + test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='./configs/nocaps.yaml') + parser.add_argument('--output_dir', default='output/NoCaps') + parser.add_argument('--device', default='cuda') + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, type=bool) + args = parser.parse_args() + + config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) + + args.result_dir = os.path.join(args.output_dir, 'result') + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + Path(args.result_dir).mkdir(parents=True, exist_ok=True) + + yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) + + main(args, config) \ No newline at end of file diff --git a/examples/ex1.jpg b/examples/ex1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4567278f5ac74725c271cee5f2ac7efea1adba59 Binary files /dev/null and b/examples/ex1.jpg differ diff --git a/examples/ex2.jpg b/examples/ex2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..33cd44f62f9e45c3972906016359fa15867c465f Binary files /dev/null and b/examples/ex2.jpg differ diff --git a/examples/ex3.jpg b/examples/ex3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..803f5e0d103c9d29f069fa2ffeb89c2dc08eed5d Binary files /dev/null and b/examples/ex3.jpg differ diff --git a/extras/.DS_Store b/extras/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..8e845ba69d54565f56960f14a573352d0da756be Binary files /dev/null and b/extras/.DS_Store differ diff --git a/extras/sample-images/0.JPG b/extras/sample-images/0.JPG new file mode 100644 index 0000000000000000000000000000000000000000..772a9270032351b2941072cf1645b99fa9dee441 Binary files /dev/null and b/extras/sample-images/0.JPG differ diff --git a/extras/sample-images/1.JPG b/extras/sample-images/1.JPG new file mode 100644 index 0000000000000000000000000000000000000000..f05b4b06a33bb6a37ec2f3bc5c145c601e9a258d Binary files /dev/null and b/extras/sample-images/1.JPG differ diff --git a/extras/sample-images/10.jpg b/extras/sample-images/10.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c19cd692170921155b3c355501b86889452bcea4 Binary files /dev/null and b/extras/sample-images/10.jpg differ diff --git a/extras/sample-images/2.jpg b/extras/sample-images/2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3c230dbdf484c10173d583f589c1d64da8c71171 Binary files /dev/null and b/extras/sample-images/2.jpg differ diff --git a/extras/sample-images/3.jpg b/extras/sample-images/3.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b5ed5faa2b277a736fb5407d476a28bba55eb9cc Binary files /dev/null and b/extras/sample-images/3.jpg differ diff --git a/extras/sample-images/4.jpg b/extras/sample-images/4.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fc265890e1d07c6d29a4228f76f24f18e4a495ac Binary files /dev/null and b/extras/sample-images/4.jpg differ diff --git a/extras/sample-images/5.jpg b/extras/sample-images/5.jpg new file mode 100644 index 0000000000000000000000000000000000000000..96f7c97f1ca973dab0a9bad5d41a6c391bc9335c Binary files /dev/null and b/extras/sample-images/5.jpg differ diff --git a/extras/sample-images/6.JPG b/extras/sample-images/6.JPG new file mode 100644 index 0000000000000000000000000000000000000000..ec413f60c9116af774fac5c921af10cfe7511b6b Binary files /dev/null and b/extras/sample-images/6.JPG differ diff --git a/extras/sample-images/7.JPG b/extras/sample-images/7.JPG new file mode 100644 index 0000000000000000000000000000000000000000..0afafcc5bf363cfa327129bf23e9f8e86c7946a1 Binary files /dev/null and b/extras/sample-images/7.JPG differ diff --git a/extras/sample-images/8.jpg b/extras/sample-images/8.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b52ac3f8ff9ea4d667517e009697e3b328fa12b6 Binary files /dev/null and b/extras/sample-images/8.jpg differ diff --git a/extras/sample-images/9.jpg b/extras/sample-images/9.jpg new file mode 100644 index 0000000000000000000000000000000000000000..b14c3b1bd48296a06dc16ae0153d415b11e3e1cb Binary files /dev/null and b/extras/sample-images/9.jpg differ diff --git a/foo.png b/foo.png new file mode 100644 index 0000000000000000000000000000000000000000..57b652339a24a1f957855d0c55fa558d1fa0f187 Binary files /dev/null and b/foo.png differ diff --git a/gradio_cached_examples/log.csv b/gradio_cached_examples/log.csv new file mode 100644 index 0000000000000000000000000000000000000000..11484f1fffbeceae4cd87204b59428019ed25bbd --- /dev/null +++ b/gradio_cached_examples/log.csv @@ -0,0 +1,2 @@ +Output +caption: a painting of a starry night over a city diff --git a/local_run.ipynb b/local_run.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3c0956b548675c1c9e8296efd77aab52d03813a8 --- /dev/null +++ b/local_run.ipynb @@ -0,0 +1,347 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on local URL: http://127.0.0.1:7860/\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(,\n", + " 'http://127.0.0.1:7860/',\n", + " None)" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2022-02-09 14:10:22.417549: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", + "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "\n", + "Total workers: 5\n", + "Number of Helmets: 4\n", + "Number of Vests: 0\n", + "dict vals:\n", + "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 5\n", + "dict vals:\n", + "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 5\n", + "Workers wearing helmet and vest: 0\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 5\n", + "dict vals:\n", + "{'W': 5, 'WH': 5, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 5\n", + "Number of Helmets: 4\n", + "Number of Vests: 0\n", + "dict vals:\n", + "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "WARNING:tensorflow:5 out of the last 5 calls to .predict_function at 0x7fbc729998b0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", + "\n", + "\n", + "\n", + "Total workers: 5\n", + "Workers wearing helmet and vest: 0\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 5\n", + "dict vals:\n", + "{'W': 5, 'WH': 5, 'WHV': 0, 'WV': 0}\n", + "WARNING:tensorflow:6 out of the last 6 calls to .predict_function at 0x7fbc979e9ee0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n", + "\n", + "\n", + "\n", + "Total workers: 3\n", + "dict vals:\n", + "{'W': 3, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 3\n", + "Workers wearing helmet and vest: 3\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 0\n", + "dict vals:\n", + "{'W': 3, 'WH': 0, 'WHV': 3, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 3\n", + "Number of Helmets: 3\n", + "Number of Vests: 1\n", + "dict vals:\n", + "{'W': 3, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 5\n", + "Number of Helmets: 4\n", + "Number of Vests: 0\n", + "dict vals:\n", + "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 6\n", + "Workers wearing helmet and vest: 0\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 4\n", + "Workers not wearing helmet and vest: 2\n", + "\n", + "\n", + "dict vals:\n", + "{'W': 6, 'WH': 4, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 6\n", + "dict vals:\n", + "{'W': 6, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 5\n", + "Number of Helmets: 4\n", + "Number of Vests: 0\n", + "dict vals:\n", + "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 6\n", + "dict vals:\n", + "{'W': 6, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 6\n", + "Workers wearing helmet and vest: 0\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 4\n", + "Workers not wearing helmet and vest: 2\n", + "\n", + "\n", + "dict vals:\n", + "{'W': 6, 'WH': 4, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 1\n", + "Number of Helmets: 1\n", + "Number of Vests: 0\n", + "dict vals:\n", + "{'W': 1, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 1\n", + "Workers wearing helmet and vest: 0\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 1\n", + "dict vals:\n", + "{'W': 1, 'WH': 1, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 1\n", + "dict vals:\n", + "{'W': 1, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 1\n", + "Workers wearing helmet and vest: 0\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 1\n", + "dict vals:\n", + "{'W': 1, 'WH': 1, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 5\n", + "Number of Helmets: 4\n", + "Number of Vests: 0\n", + "dict vals:\n", + "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 6\n", + "Workers wearing helmet and vest: 0\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 4\n", + "Workers not wearing helmet and vest: 2\n", + "\n", + "\n", + "dict vals:\n", + "{'W': 6, 'WH': 4, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 3\n", + "Workers wearing helmet and vest: 3\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 0\n", + "dict vals:\n", + "{'W': 3, 'WH': 0, 'WHV': 3, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 3\n", + "dict vals:\n", + "{'W': 3, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 3\n", + "Number of Helmets: 3\n", + "Number of Vests: 1\n", + "dict vals:\n", + "{'W': 3, 'WH': 0, 'WHV': 0, 'WV': 0}\n", + "\n", + "\n", + "\n", + "Total workers: 3\n", + "Workers wearing helmet and vest: 3\n", + "Workers wearing only vest: 0\n", + "Workers wearing only helmet: 0\n", + "dict vals:\n", + "{'W': 3, 'WH': 0, 'WHV': 3, 'WV': 0}\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import run_code\n", + "import cv2\n", + "import gradio as gr\n", + "\n", + "\n", + "def sepia(Input_Image, Approach):\n", + " pil_image = Input_Image\n", + " open_cv_image = np.asarray(pil_image)\n", + " # Convert RGB to BGR\n", + " #open_cv_image = open_cv_image[:, :, ::-1].copy()\n", + " #Approach = 3\n", + " sepia_img = run_code.run(open_cv_image, Approach)\n", + " images = sepia_img['img']\n", + " texts= sepia_img['text']\n", + " #print (labels)\n", + " return images, texts\n", + "\n", + "image = [gr.inputs.Image(type=\"pil\"), gr.inputs.Radio([1, 2, 3])]\n", + "#output = [\"image\", gr.outputs.Label(num_top_classes=4)]\n", + "output = [\"image\", gr.outputs.Textbox(type=\"auto\")]\n", + "#output = gr.outputs.Label(num_top_classes=4)\n", + "\n", + "title=\"Real-time Detection of Personal-Protective-Equipment (PPE)\"\n", + "description=\"This demo is the implementation of Real-time Detection of Personal-Protective-Equipment (PPE) paper https://github.com/ciber-lab/pictor-ppe\" \\\n", + " \" - by Sanjay Kamath \"\n", + "examples = [[\"examples/ex1.jpg\", 1], [\"examples/ex2.jpg\", 2], [\"examples/ex3.jpg\", 3]]\n", + "\n", + "#iface = gr.Interface(sepia , [ gr.inputs.Image(shape=(200, 200)), gr.inputs.Radio([1, 2, 3])], \"image\", title=title,\n", + "# examples = [[\"examples/ex1.jpg\"], [\"examples/ex2.jpg\"], [\"examples/ex3.jpg\"]],\n", + "# description=description)\n", + "\n", + "iface = gr.Interface(fn=sepia, inputs=image, outputs=output, title=title, description=description, examples=examples)\n", + "\n", + "iface.launch()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/model-data/.DS_Store b/model-data/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..413f62de131c21a2b324bf30aee88e5f1c4264e0 Binary files /dev/null and b/model-data/.DS_Store differ diff --git a/model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5 b/model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5 new file mode 100644 index 0000000000000000000000000000000000000000..10ffd04a67112b71ee728e867c5a5d8cee0df4aa --- /dev/null +++ b/model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3ec800aa5acdd9719ff5e63b34d1374e5c8a31e17f38f3a8250bf1aeeac1a972 +size 246910096 diff --git a/model-data/weights/pictor-ppe-v302-a2-yolo-v3-weights.h5 b/model-data/weights/pictor-ppe-v302-a2-yolo-v3-weights.h5 new file mode 100644 index 0000000000000000000000000000000000000000..0c0f5d084744391e54df6c72d13a3f252e25d5fa --- /dev/null +++ b/model-data/weights/pictor-ppe-v302-a2-yolo-v3-weights.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:317831ba378b8ec02e24e57859876eb0348284c8a75155143c9df85ee478c47b +size 246931600 diff --git a/model-data/weights/pictor-ppe-v302-a3-yolo-v3-weights.h5 b/model-data/weights/pictor-ppe-v302-a3-yolo-v3-weights.h5 new file mode 100644 index 0000000000000000000000000000000000000000..a98eb218705677fc8618e45284a04c7916ad27af --- /dev/null +++ b/model-data/weights/pictor-ppe-v302-a3-yolo-v3-weights.h5 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d06d4956d0f6b3ac71f02e103e9efdc4b222ce83aeae232f65ee6c04ee1dd2d7 +size 246867088 diff --git a/model-data/weights/readme.md b/model-data/weights/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..3ef4b304f4a53e4bfe5e20f0c2e21468f522fb2c --- /dev/null +++ b/model-data/weights/readme.md @@ -0,0 +1 @@ +Download the trained weights of YOLO models ([Google Drive folder](https://drive.google.com/drive/folders/13tCdROHnS0c5VibW1VO8pOEj0rXEvvGj?usp=sharing)) and put in this folder. diff --git a/modelsn/__init__.py b/modelsn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modelsn/__pycache__/__init__.cpython-38.pyc b/modelsn/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..804a103246f8061d570bb7da5e859107bd5d4810 Binary files /dev/null and b/modelsn/__pycache__/__init__.cpython-38.pyc differ diff --git a/modelsn/__pycache__/blip.cpython-38.pyc b/modelsn/__pycache__/blip.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5c8007ea324c7967f57e63f3da0a831c02466eb Binary files /dev/null and b/modelsn/__pycache__/blip.cpython-38.pyc differ diff --git a/modelsn/__pycache__/blip_vqa.cpython-38.pyc b/modelsn/__pycache__/blip_vqa.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db772d76dc35c405405dbe0cab595cfc79967b75 Binary files /dev/null and b/modelsn/__pycache__/blip_vqa.cpython-38.pyc differ diff --git a/modelsn/__pycache__/med.cpython-38.pyc b/modelsn/__pycache__/med.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00614c279380d1433fca29841fc58a37fcc5fd78 Binary files /dev/null and b/modelsn/__pycache__/med.cpython-38.pyc differ diff --git a/modelsn/__pycache__/vit.cpython-38.pyc b/modelsn/__pycache__/vit.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aed4ba4ca415c1f68c4953adb5b0c8f9b988f947 Binary files /dev/null and b/modelsn/__pycache__/vit.cpython-38.pyc differ diff --git a/modelsn/blip.py b/modelsn/blip.py new file mode 100644 index 0000000000000000000000000000000000000000..b861ada2ea6e17ef3439e583b44e61cc0614df05 --- /dev/null +++ b/modelsn/blip.py @@ -0,0 +1,238 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import warnings +warnings.filterwarnings("ignore") + +from modelsn.vit import VisionTransformer, interpolate_pos_embed +from modelsn.med import BertConfig, BertModel, BertLMHeadModel +from transformers import BertTokenizer + +import torch +from torch import nn +import torch.nn.functional as F + +import os +from urllib.parse import urlparse +from timm.models.hub import download_cached_file + +class BLIP_Base(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 224, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + + def forward(self, image, caption, mode): + + assert mode in ['image', 'text', 'multimodal'], "mode parameter must be image, text, or multimodal" + text = self.tokenizer(caption, return_tensors="pt").to(image.device) + + if mode=='image': + # return image features + image_embeds = self.visual_encoder(image) + return image_embeds + + elif mode=='text': + # return text features + text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + return text_output.last_hidden_state + + elif mode=='multimodal': + # return multimodel features + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text.input_ids[:,0] = self.tokenizer.enc_token_id + output = self.text_encoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + return output.last_hidden_state + + + +class BLIP_Decoder(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 384, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + prompt = 'a picture of ', + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_decoder = BertLMHeadModel(config=med_config) + + self.prompt = prompt + self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1 + + + def forward(self, image, caption): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + text = self.tokenizer(caption, padding='longest', truncation=True, max_length=40, return_tensors="pt").to(image.device) + + text.input_ids[:,0] = self.tokenizer.bos_token_id + + decoder_targets = text.input_ids.masked_fill(text.input_ids == self.tokenizer.pad_token_id, -100) + decoder_targets[:,:self.prompt_length] = -100 + + decoder_output = self.text_decoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + labels = decoder_targets, + return_dict = True, + ) + loss_lm = decoder_output.loss + + return loss_lm + + def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0): + image_embeds = self.visual_encoder(image) + + if not sample: + image_embeds = image_embeds.repeat_interleave(num_beams,dim=0) + + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + model_kwargs = {"encoder_hidden_states": image_embeds, "encoder_attention_mask":image_atts} + + prompt = [self.prompt] * image.size(0) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device) + input_ids[:,0] = self.tokenizer.bos_token_id + input_ids = input_ids[:, :-1] + + if sample: + #nucleus sampling + outputs = self.text_decoder.generate(input_ids=input_ids, + max_length=max_length, + min_length=min_length, + do_sample=True, + top_p=top_p, + num_return_sequences=1, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=1.1, + **model_kwargs) + else: + #beam search + outputs = self.text_decoder.generate(input_ids=input_ids, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + repetition_penalty=repetition_penalty, + **model_kwargs) + + captions = [] + for output in outputs: + caption = self.tokenizer.decode(output, skip_special_tokens=True) + captions.append(caption[len(self.prompt):]) + return captions + + +def blip_decoder(pretrained='',**kwargs): + model = BLIP_Decoder(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + assert(len(msg.missing_keys)==0) + return model + +def blip_feature_extractor(pretrained='',**kwargs): + model = BLIP_Base(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + assert(len(msg.missing_keys)==0) + return model + +def init_tokenizer(): + tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') + tokenizer.add_special_tokens({'bos_token':'[DEC]'}) + tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']}) + tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] + return tokenizer + + +def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0): + + assert vit in ['base', 'large'], "vit parameter must be base or large" + if vit=='base': + vision_width = 768 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12, + num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0 or drop_path_rate + ) + elif vit=='large': + vision_width = 1024 + visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24, + num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer, + drop_path_rate=0.1 or drop_path_rate + ) + return visual_encoder, vision_width + +def is_url(url_or_filename): + parsed = urlparse(url_or_filename) + return parsed.scheme in ("http", "https") + +def load_checkpoint(model,url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) + if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): + state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], + model.visual_encoder_m) + for key in model.state_dict().keys(): + if key in state_dict.keys(): + if state_dict[key].shape!=model.state_dict()[key].shape: + del state_dict[key] + + msg = model.load_state_dict(state_dict,strict=False) + print('load checkpoint from %s'%url_or_filename) + return model,msg + diff --git a/modelsn/blip_nlvr.py b/modelsn/blip_nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..95ad209b33b325bd8434f6809285bd2feedb90f4 --- /dev/null +++ b/modelsn/blip_nlvr.py @@ -0,0 +1,103 @@ +from modelsn.med import BertConfig +from modelsn.nlvr_encoder import BertModel +from modelsn.vit import interpolate_pos_embed +from modelsn.blip import create_vit, init_tokenizer, is_url + +from timm.models.hub import download_cached_file + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import BertTokenizer +import numpy as np + +class BLIP_NLVR(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 480, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + self.cls_head = nn.Sequential( + nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), + nn.ReLU(), + nn.Linear(self.text_encoder.config.hidden_size, 2) + ) + + def forward(self, image, text, targets, train=True): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) + + text = self.tokenizer(text, padding='longest', return_tensors="pt").to(image.device) + text.input_ids[:,0] = self.tokenizer.enc_token_id + + output = self.text_encoder(text.input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = [image0_embeds,image1_embeds], + encoder_attention_mask = [image_atts[:image0_embeds.size(0)], + image_atts[image0_embeds.size(0):]], + return_dict = True, + ) + hidden_state = output.last_hidden_state[:,0,:] + prediction = self.cls_head(hidden_state) + + if train: + loss = F.cross_entropy(prediction, targets) + return loss + else: + return prediction + +def blip_nlvr(pretrained='',**kwargs): + model = BLIP_NLVR(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + print("missing keys:") + print(msg.missing_keys) + return model + + +def load_checkpoint(model,url_or_filename): + if is_url(url_or_filename): + cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True) + checkpoint = torch.load(cached_file, map_location='cpu') + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location='cpu') + else: + raise RuntimeError('checkpoint url or path is invalid') + state_dict = checkpoint['model'] + + state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) + + for key in list(state_dict.keys()): + if 'crossattention.self.' in key: + new_key0 = key.replace('self','self0') + new_key1 = key.replace('self','self1') + state_dict[new_key0] = state_dict[key] + state_dict[new_key1] = state_dict[key] + elif 'crossattention.output.dense.' in key: + new_key0 = key.replace('dense','dense0') + new_key1 = key.replace('dense','dense1') + state_dict[new_key0] = state_dict[key] + state_dict[new_key1] = state_dict[key] + + msg = model.load_state_dict(state_dict,strict=False) + print('load checkpoint from %s'%url_or_filename) + return model,msg + \ No newline at end of file diff --git a/modelsn/blip_pretrain.py b/modelsn/blip_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..e9f3b799aed7099e0daaa71521c10a42a0e413e9 --- /dev/null +++ b/modelsn/blip_pretrain.py @@ -0,0 +1,339 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +from modelsn.med import BertConfig, BertModel, BertLMHeadModel +from transformers import BertTokenizer +import transformers +transformers.logging.set_verbosity_error() + +import torch +from torch import nn +import torch.nn.functional as F + +from modelsn.blip import create_vit, init_tokenizer, load_checkpoint + +class BLIP_Pretrain(nn.Module): + def __init__(self, + med_config = 'configs/bert_config.json', + image_size = 224, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + embed_dim = 256, + queue_size = 57600, + momentum = 0.995, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer, 0) + + if vit=='base': + checkpoint = torch.hub.load_state_dict_from_url( + url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", + map_location="cpu", check_hash=True) + state_dict = checkpoint["model"] + msg = self.visual_encoder.load_state_dict(state_dict,strict=False) + elif vit=='large': + from timm.models.helpers import load_custom_pretrained + from timm.models.vision_transformer import default_cfgs + load_custom_pretrained(self.visual_encoder,default_cfgs['vit_large_patch16_224_in21k']) + + self.tokenizer = init_tokenizer() + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = vision_width + self.text_encoder = BertModel.from_pretrained('bert-base-uncased',config=encoder_config, add_pooling_layer=False) + self.text_encoder.resize_token_embeddings(len(self.tokenizer)) + + text_width = self.text_encoder.config.hidden_size + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + # create momentum encoders + self.visual_encoder_m, vision_width = create_vit(vit,image_size) + self.vision_proj_m = nn.Linear(vision_width, embed_dim) + self.text_encoder_m = BertModel(config=encoder_config, add_pooling_layer=False) + self.text_proj_m = nn.Linear(text_width, embed_dim) + + self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], + [self.vision_proj,self.vision_proj_m], + [self.text_encoder,self.text_encoder_m], + [self.text_proj,self.text_proj_m], + ] + self.copy_params() + + # create the queue + self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + self.image_queue = nn.functional.normalize(self.image_queue, dim=0) + self.text_queue = nn.functional.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + self.momentum = momentum + self.temp = nn.Parameter(0.07*torch.ones([])) + + # create the decoder + decoder_config = BertConfig.from_json_file(med_config) + decoder_config.encoder_width = vision_width + self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config) + self.text_decoder.resize_token_embeddings(len(self.tokenizer)) + tie_encoder_decoder_weights(self.text_decoder.bert,self.text_encoder,'','/attention') + + + def forward(self, image, caption, alpha): + with torch.no_grad(): + self.temp.clamp_(0.001,0.5) + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) + + text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=30, + return_tensors="pt").to(image.device) + text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m(image) + image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) + image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) + + text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) + text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) + + sim_i2t_m = image_feat_m @ text_feat_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_all / self.temp + + sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) + sim_targets.fill_diagonal_(1) + + sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets + sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets + + sim_i2t = image_feat @ text_feat_all / self.temp + sim_t2i = text_feat @ image_feat_all / self.temp + + loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() + loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() + + loss_ita = (loss_i2t+loss_t2i)/2 + + self._dequeue_and_enqueue(image_feat_m, text_feat_m) + + ###============== Image-text Matching ===================### + encoder_input_ids = text.input_ids.clone() + encoder_input_ids[:,0] = self.tokenizer.enc_token_id + + # forward the positve image-text pair + bs = image.size(0) + output_pos = self.text_encoder(encoder_input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + with torch.no_grad(): + weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1)+1e-4 + weights_t2i.fill_diagonal_(0) + weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1)+1e-4 + weights_i2t.fill_diagonal_(0) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg,dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(encoder_input_ids[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg,dim=0) + text_atts_neg = torch.stack(text_atts_neg,dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) + + image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) + image_atts_all = torch.cat([image_atts,image_atts],dim=0) + + output_neg = self.text_encoder(text_ids_all, + attention_mask = text_atts_all, + encoder_hidden_states = image_embeds_all, + encoder_attention_mask = image_atts_all, + return_dict = True, + ) + + vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) + vl_output = self.itm_head(vl_embeddings) + + itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], + dim=0).to(image.device) + loss_itm = F.cross_entropy(vl_output, itm_labels) + + ##================= LM ========================## + decoder_input_ids = text.input_ids.clone() + decoder_input_ids[:,0] = self.tokenizer.bos_token_id + decoder_targets = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100) + + decoder_output = self.text_decoder(decoder_input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + labels = decoder_targets, + return_dict = True, + ) + + loss_lm = decoder_output.loss + return loss_ita, loss_itm, loss_lm + + + + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): + param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) + + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat): + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + + batch_size = image_feats.shape[0] + + ptr = int(self.queue_ptr) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + ptr = (ptr + batch_size) % self.queue_size # move pointer + + self.queue_ptr[0] = ptr + + +def blip_pretrain(**kwargs): + model = BLIP_Pretrain(**kwargs) + return model + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +from typing import List +def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." + ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + skip_key: str, + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" + if hasattr(decoder_pointer, "weight") and skip_key not in module_name: + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + encoder_pointer.bias = decoder_pointer.bias + print(module_name+' is tied') + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( + encoder_modules + ) != len(decoder_modules): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." + ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + uninitialized_encoder_weights, + skip_key, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key) diff --git a/modelsn/blip_retrieval.py b/modelsn/blip_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..8becd3c7fd18201295cf54cb34de49914de3096a --- /dev/null +++ b/modelsn/blip_retrieval.py @@ -0,0 +1,322 @@ +from modelsn.med import BertConfig, BertModel +from transformers import BertTokenizer + +import torch +from torch import nn +import torch.nn.functional as F + +from modelsn.blip import create_vit, init_tokenizer, load_checkpoint + +class BLIP_Retrieval(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 384, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + embed_dim = 256, + queue_size = 57600, + momentum = 0.995, + negative_all_rank = False, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer) + self.tokenizer = init_tokenizer() + med_config = BertConfig.from_json_file(med_config) + med_config.encoder_width = vision_width + self.text_encoder = BertModel(config=med_config, add_pooling_layer=False) + + text_width = self.text_encoder.config.hidden_size + + self.vision_proj = nn.Linear(vision_width, embed_dim) + self.text_proj = nn.Linear(text_width, embed_dim) + + self.itm_head = nn.Linear(text_width, 2) + + # create momentum encoders + self.visual_encoder_m, vision_width = create_vit(vit,image_size) + self.vision_proj_m = nn.Linear(vision_width, embed_dim) + self.text_encoder_m = BertModel(config=med_config, add_pooling_layer=False) + self.text_proj_m = nn.Linear(text_width, embed_dim) + + self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], + [self.vision_proj,self.vision_proj_m], + [self.text_encoder,self.text_encoder_m], + [self.text_proj,self.text_proj_m], + ] + self.copy_params() + + # create the queue + self.register_buffer("image_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("text_queue", torch.randn(embed_dim, queue_size)) + self.register_buffer("idx_queue", torch.full((1,queue_size),-100)) + self.register_buffer("ptr_queue", torch.zeros(1, dtype=torch.long)) + + self.image_queue = nn.functional.normalize(self.image_queue, dim=0) + self.text_queue = nn.functional.normalize(self.text_queue, dim=0) + + self.queue_size = queue_size + self.momentum = momentum + self.temp = nn.Parameter(0.07*torch.ones([])) + + self.negative_all_rank = negative_all_rank + + + def forward(self, image, caption, alpha, idx): + with torch.no_grad(): + self.temp.clamp_(0.001,0.5) + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) + + text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35, + return_tensors="pt").to(image.device) + + text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1) + + ###============== Image-text Contrastive Learning ===================### + idx = idx.view(-1,1) + idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) + + # get momentum features + with torch.no_grad(): + self._momentum_update() + image_embeds_m = self.visual_encoder_m(image) + image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) + image_feat_m_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) + + text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, + return_dict = True, mode = 'text') + text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) + text_feat_m_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) + + sim_i2t_m = image_feat_m @ text_feat_m_all / self.temp + sim_t2i_m = text_feat_m @ image_feat_m_all / self.temp + + sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) + sim_targets.fill_diagonal_(1) + + sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets + sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets + + sim_i2t = image_feat @ text_feat_m_all / self.temp + sim_t2i = text_feat @ image_feat_m_all / self.temp + + loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() + loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() + + loss_ita = (loss_i2t+loss_t2i)/2 + + idxs = concat_all_gather(idx) + self._dequeue_and_enqueue(image_feat_m, text_feat_m, idxs) + + ###============== Image-text Matching ===================### + encoder_input_ids = text.input_ids.clone() + encoder_input_ids[:,0] = self.tokenizer.enc_token_id + + # forward the positve image-text pair + bs = image.size(0) + output_pos = self.text_encoder(encoder_input_ids, + attention_mask = text.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True, + ) + + + if self.negative_all_rank: + # compute sample similarity + with torch.no_grad(): + mask = torch.eq(idx, idxs.t()) + + image_feat_world = concat_all_gather(image_feat) + text_feat_world = concat_all_gather(text_feat) + + sim_i2t = image_feat @ text_feat_world.t() / self.temp + sim_t2i = text_feat @ image_feat_world.t() / self.temp + + weights_i2t = F.softmax(sim_i2t,dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = F.softmax(sim_t2i,dim=1) + weights_t2i.masked_fill_(mask, 0) + + image_embeds_world = all_gather_with_grad(image_embeds) + + # select a negative image (from all ranks) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg,dim=0) + + # select a negative text (from all ranks) for each image + input_ids_world = concat_all_gather(encoder_input_ids) + att_mask_world = concat_all_gather(text.attention_mask) + + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(input_ids_world[neg_idx]) + text_atts_neg.append(att_mask_world[neg_idx]) + + else: + with torch.no_grad(): + mask = torch.eq(idx, idx.t()) + + sim_i2t = image_feat @ text_feat.t() / self.temp + sim_t2i = text_feat @ image_feat.t() / self.temp + + weights_i2t = F.softmax(sim_i2t,dim=1) + weights_i2t.masked_fill_(mask, 0) + + weights_t2i = F.softmax(sim_t2i,dim=1) + weights_t2i.masked_fill_(mask, 0) + + # select a negative image (from same rank) for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg,dim=0) + + # select a negative text (from same rank) for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(encoder_input_ids[neg_idx]) + text_atts_neg.append(text.attention_mask[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg,dim=0) + text_atts_neg = torch.stack(text_atts_neg,dim=0) + + text_ids_all = torch.cat([encoder_input_ids, text_ids_neg],dim=0) + text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) + + image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) + image_atts_all = torch.cat([image_atts,image_atts],dim=0) + + output_neg = self.text_encoder(text_ids_all, + attention_mask = text_atts_all, + encoder_hidden_states = image_embeds_all, + encoder_attention_mask = image_atts_all, + return_dict = True, + ) + + + vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) + vl_output = self.itm_head(vl_embeddings) + + itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], + dim=0).to(image.device) + loss_itm = F.cross_entropy(vl_output, itm_labels) + + return loss_ita, loss_itm + + + @torch.no_grad() + def copy_params(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): + param_m.data.copy_(param.data) # initialize + param_m.requires_grad = False # not update by gradient + + + @torch.no_grad() + def _momentum_update(self): + for model_pair in self.model_pairs: + for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): + param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) + + + @torch.no_grad() + def _dequeue_and_enqueue(self, image_feat, text_feat, idxs): + # gather keys before updating queue + image_feats = concat_all_gather(image_feat) + text_feats = concat_all_gather(text_feat) + + + batch_size = image_feats.shape[0] + + ptr = int(self.ptr_queue) + assert self.queue_size % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.image_queue[:, ptr:ptr + batch_size] = image_feats.T + self.text_queue[:, ptr:ptr + batch_size] = text_feats.T + self.idx_queue[:, ptr:ptr + batch_size] = idxs.T + ptr = (ptr + batch_size) % self.queue_size # move pointer + + self.ptr_queue[0] = ptr + + +def blip_retrieval(pretrained='',**kwargs): + model = BLIP_Retrieval(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) + print("missing keys:") + print(msg.missing_keys) + return model + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +class GatherLayer(torch.autograd.Function): + """ + Gather tensors from all workers with support for backward propagation: + This implementation does not cut the gradients as torch.distributed.all_gather does. + """ + + @staticmethod + def forward(ctx, x): + output = [torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output, x) + return tuple(output) + + @staticmethod + def backward(ctx, *grads): + all_gradients = torch.stack(grads) + torch.distributed.all_reduce(all_gradients) + return all_gradients[torch.distributed.get_rank()] + + +def all_gather_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) diff --git a/modelsn/blip_vqa.py b/modelsn/blip_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..25ebae6a9f9c83fc7f5d2225fd1e36e321488283 --- /dev/null +++ b/modelsn/blip_vqa.py @@ -0,0 +1,186 @@ +from modelsn.med import BertConfig, BertModel, BertLMHeadModel +from modelsn.blip import create_vit, init_tokenizer, load_checkpoint + +import torch +from torch import nn +import torch.nn.functional as F +from transformers import BertTokenizer +import numpy as np + +class BLIP_VQA(nn.Module): + def __init__(self, + med_config = 'configs/med_config.json', + image_size = 480, + vit = 'base', + vit_grad_ckpt = False, + vit_ckpt_layer = 0, + ): + """ + Args: + med_config (str): path for the mixture of encoder-decoder model's configuration file + image_size (int): input image size + vit (str): model size of vision transformer + """ + super().__init__() + + self.visual_encoder, vision_width = create_vit(vit, image_size, vit_grad_ckpt, vit_ckpt_layer, drop_path_rate=0.1) + self.tokenizer = init_tokenizer() + + encoder_config = BertConfig.from_json_file(med_config) + encoder_config.encoder_width = vision_width + self.text_encoder = BertModel(config=encoder_config, add_pooling_layer=False) + + decoder_config = BertConfig.from_json_file(med_config) + self.text_decoder = BertLMHeadModel(config=decoder_config) + + + def forward(self, image, question, answer=None, n=None, weights=None, train=True, inference='rank', k_test=128): + + image_embeds = self.visual_encoder(image) + image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) + + question = self.tokenizer(question, padding='longest', truncation=True, max_length=35, + return_tensors="pt").to(image.device) + question.input_ids[:,0] = self.tokenizer.enc_token_id + + if train: + ''' + n: number of answers for each question + weights: weight for each answer + ''' + answer = self.tokenizer(answer, padding='longest', return_tensors="pt").to(image.device) + answer.input_ids[:,0] = self.tokenizer.bos_token_id + answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) + + question_output = self.text_encoder(question.input_ids, + attention_mask = question.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True) + + question_states = [] + question_atts = [] + for b, n in enumerate(n): + question_states += [question_output.last_hidden_state[b]]*n + question_atts += [question.attention_mask[b]]*n + question_states = torch.stack(question_states,0) + question_atts = torch.stack(question_atts,0) + + answer_output = self.text_decoder(answer.input_ids, + attention_mask = answer.attention_mask, + encoder_hidden_states = question_states, + encoder_attention_mask = question_atts, + labels = answer_targets, + return_dict = True, + reduction = 'none', + ) + + loss = weights * answer_output.loss + loss = loss.sum()/image.size(0) + + return loss + + + else: + question_output = self.text_encoder(question.input_ids, + attention_mask = question.attention_mask, + encoder_hidden_states = image_embeds, + encoder_attention_mask = image_atts, + return_dict = True) + + if inference=='generate': + num_beams = 3 + question_states = question_output.last_hidden_state.repeat_interleave(num_beams,dim=0) + question_atts = torch.ones(question_states.size()[:-1],dtype=torch.long).to(question_states.device) + model_kwargs = {"encoder_hidden_states": question_states, "encoder_attention_mask":question_atts} + + bos_ids = torch.full((image.size(0),1),fill_value=self.tokenizer.bos_token_id,device=image.device) + + outputs = self.text_decoder.generate(input_ids=bos_ids, + max_length=10, + min_length=1, + num_beams=num_beams, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + **model_kwargs) + + answers = [] + for output in outputs: + answer = self.tokenizer.decode(output, skip_special_tokens=True) + answers.append(answer) + return answers + + elif inference=='rank': + max_ids = self.rank_answer(question_output.last_hidden_state, question.attention_mask, + answer.input_ids, answer.attention_mask, k_test) + return max_ids + + + + def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): + + num_ques = question_states.size(0) + start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token + + start_output = self.text_decoder(start_ids, + encoder_hidden_states = question_states, + encoder_attention_mask = question_atts, + return_dict = True, + reduction = 'none') + logits = start_output.logits[:,0,:] # first token's logit + + # topk_probs: top-k probability + # topk_ids: [num_question, k] + answer_first_token = answer_ids[:,1] + prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) + topk_probs, topk_ids = prob_first_token.topk(k,dim=1) + + # answer input: [num_question*k, answer_len] + input_ids = [] + input_atts = [] + for b, topk_id in enumerate(topk_ids): + input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) + input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) + input_ids = torch.cat(input_ids,dim=0) + input_atts = torch.cat(input_atts,dim=0) + + targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) + + # repeat encoder's output for top-k answers + question_states = tile(question_states, 0, k) + question_atts = tile(question_atts, 0, k) + + output = self.text_decoder(input_ids, + attention_mask = input_atts, + encoder_hidden_states = question_states, + encoder_attention_mask = question_atts, + labels = targets_ids, + return_dict = True, + reduction = 'none') + + log_probs_sum = -output.loss + log_probs_sum = log_probs_sum.view(num_ques,k) + + max_topk_ids = log_probs_sum.argmax(dim=1) + max_ids = topk_ids[max_topk_ids>=0,max_topk_ids] + + return max_ids + + +def blip_vqa(pretrained='',**kwargs): + model = BLIP_VQA(**kwargs) + if pretrained: + model,msg = load_checkpoint(model,pretrained) +# assert(len(msg.missing_keys)==0) + return model + + +def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) + return torch.index_select(x, dim, order_index.to(x.device)) + + \ No newline at end of file diff --git a/modelsn/med.py b/modelsn/med.py new file mode 100644 index 0000000000000000000000000000000000000000..7b00a35450b736180a805d4f4664b4fb95aeba01 --- /dev/null +++ b/modelsn/med.py @@ -0,0 +1,955 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +''' + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode=='multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction='mean', + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + mode=mode, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + if reduction=='none': + lm_loss = lm_loss.view(prediction_scores.size(0),-1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past diff --git a/modelsn/nlvr_encoder.py b/modelsn/nlvr_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1946bb4a300f75afa4848f6622839445903c34a9 --- /dev/null +++ b/modelsn/nlvr_encoder.py @@ -0,0 +1,843 @@ +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + + +logger = logging.get_logger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + + self.config = config + + def forward( + self, input_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + embeddings = inputs_embeds + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config, twin=False, merge=False): + super().__init__() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if twin: + self.dense0 = nn.Linear(config.hidden_size, config.hidden_size) + self.dense1 = nn.Linear(config.hidden_size, config.hidden_size) + else: + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if merge: + self.act = ACT2FN[config.hidden_act] + self.merge_layer = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.merge = True + else: + self.merge = False + + def forward(self, hidden_states, input_tensor): + if type(hidden_states) == list: + hidden_states0 = self.dense0(hidden_states[0]) + hidden_states1 = self.dense1(hidden_states[1]) + if self.merge: + #hidden_states = self.merge_layer(self.act(torch.cat([hidden_states0,hidden_states1],dim=-1))) + hidden_states = self.merge_layer(torch.cat([hidden_states0,hidden_states1],dim=-1)) + else: + hidden_states = (hidden_states0+hidden_states1)/2 + else: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_num=-1): + super().__init__() + if is_cross_attention: + self.self0 = BertSelfAttention(config, is_cross_attention) + self.self1 = BertSelfAttention(config, is_cross_attention) + else: + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config, twin=is_cross_attention, merge=(is_cross_attention and layer_num>=6)) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + if type(encoder_hidden_states)==list: + self_outputs0 = self.self0( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[0], + encoder_attention_mask[0], + past_key_value, + output_attentions, + ) + self_outputs1 = self.self1( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states[1], + encoder_attention_mask[1], + past_key_value, + output_attentions, + ) + attention_output = self.output([self_outputs0[0],self_outputs1[0]], hidden_states) + + outputs = (attention_output,) + self_outputs0[1:] # add attentions if we output them + else: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if self.config.add_cross_attention: + self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention, layer_num=layer_num) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + mode=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + + if mode=='multimodal': + assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([BertLayer(config,i) for i in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + mode='multimodal', + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + mode=mode, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + mode=mode, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + + def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + encoder_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + mode='multimodal', + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + device = input_ids.device + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = inputs_embeds.device + elif encoder_embeds is not None: + input_shape = encoder_embeds.size()[:-1] + batch_size, seq_length = input_shape + device = encoder_embeds.device + else: + raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, + device, is_decoder) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size() + else: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if encoder_embeds is None: + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + else: + embedding_output = encoder_embeds + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + mode=mode, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + diff --git a/modelsn/vit.py b/modelsn/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..cec3d8e08ed4451d65392feb2e9f4848d1ef3899 --- /dev/null +++ b/modelsn/vit.py @@ -0,0 +1,305 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on timm code base + * https://github.com/rwightman/pytorch-image-models/tree/master/timm +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.models.vision_transformer import _cfg, PatchEmbed +from timm.models.registry import register_model +from timm.models.layers import trunc_normal_, DropPath +from timm.models.helpers import named_apply, adapt_input_conv + +from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.attn_gradients = None + self.attention_map = None + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def forward(self, x, register_hook=False): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + if register_hook: + self.save_attention_map(attn) + attn.register_hook(self.save_attn_gradients) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if use_grad_checkpointing: + self.attn = checkpoint_wrapper(self.attn) + self.mlp = checkpoint_wrapper(self.mlp) + + def forward(self, x, register_hook=False): + x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, + use_grad_checkpointing=False, ckpt_layer=0): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer: (nn.Module): normalization layer + """ + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) + ) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + trunc_normal_(self.pos_embed, std=.02) + trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward(self, x, register_blk=-1): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + self.pos_embed[:,:x.size(1),:] + x = self.pos_drop(x) + + for i,blk in enumerate(self.blocks): + x = blk(x, register_blk==i) + x = self.norm(x) + + return x + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: + w = w.flatten() + if t: + if w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix and 'opt/target/embedding/kernel' in w: + prefix = 'opt/target/' + + if hasattr(model.patch_embed, 'backbone'): + # hybrid + backbone = model.patch_embed.backbone + stem_only = not hasattr(backbone, 'stem') + stem = backbone if stem_only else backbone.stem + stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) + stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) + stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) + if not stem_only: + for i, stage in enumerate(backbone.stages): + for j, block in enumerate(stage.blocks): + bp = f'{prefix}block{i + 1}/unit{j + 1}/' + for r in range(3): + getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) + getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) + getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) + if block.downsample is not None: + block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) + block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) + block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) + embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) + else: + embed_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(embed_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) + if pos_embed_w.shape != model.pos_embed.shape: + pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights + pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + model.pos_embed.copy_(pos_embed_w) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) +# if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: +# model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) +# model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) +# if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: +# model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) +# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + for r in range(2): + getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) + getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) + + +def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): + # interpolate position embedding + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = visual_encoder.patch_embed.num_patches + num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + + if orig_size!=new_size: + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) + + return new_pos_embed + else: + return pos_embed_checkpoint \ No newline at end of file diff --git a/pretrain.py b/pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..c9490ec8eb8ff5f074b5772ada55cd27ec673a12 --- /dev/null +++ b/pretrain.py @@ -0,0 +1,173 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import argparse +import os +import ruamel_yaml as yaml +import numpy as np +import random +import time +import datetime +import json +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +import torch.distributed as dist +from torch.utils.data import DataLoader + +from models.blip_pretrain import blip_pretrain +import utils +from utils import warmup_lr_schedule, step_lr_schedule +from data import create_dataset, create_sampler, create_loader + +def train(model, data_loader, optimizer, epoch, device, config): + # train + model.train() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) + metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) + metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) + metric_logger.add_meter('loss_lm', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) + + header = 'Train Epoch: [{}]'.format(epoch) + print_freq = 50 + + if config['laion_path']: + data_loader.dataset.reload_laion(epoch) + + data_loader.sampler.set_epoch(epoch) + + for i, (image, caption) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + + if epoch==0: + warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) + + optimizer.zero_grad() + + image = image.to(device,non_blocking=True) + + # ramp up alpha in the first 2 epochs + alpha = config['alpha']*min(1,(epoch*len(data_loader)+i)/(2*len(data_loader))) + + loss_ita, loss_itm, loss_lm = model(image, caption, alpha = alpha) + loss = loss_ita + loss_itm + loss_lm + + loss.backward() + optimizer.step() + + metric_logger.update(loss_ita=loss_ita.item()) + metric_logger.update(loss_itm=loss_itm.item()) + metric_logger.update(loss_lm=loss_lm.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger.global_avg()) + return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} + + +def main(args, config): + utils.init_distributed_mode(args) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = True + + #### Dataset #### + print("Creating dataset") + datasets = [create_dataset('pretrain', config, min_scale=0.2)] + print('number of training samples: %d'%len(datasets[0])) + + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + samplers = create_sampler(datasets, [True], num_tasks, global_rank) + + data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0] + + #### Model #### + print("Creating model") + model = blip_pretrain(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], + vit_ckpt_layer=config['vit_ckpt_layer'], queue_size=config['queue_size']) + + model = model.to(device) + + optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) + + start_epoch = 0 + if args.checkpoint: + checkpoint = torch.load(args.checkpoint, map_location='cpu') + state_dict = checkpoint['model'] + model.load_state_dict(state_dict) + + optimizer.load_state_dict(checkpoint['optimizer']) + start_epoch = checkpoint['epoch']+1 + print('resume checkpoint from %s'%args.checkpoint) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + print("Start training") + start_time = time.time() + for epoch in range(start_epoch, config['max_epoch']): + + step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate']) + + train_stats = train(model, data_loader, optimizer, epoch, device, config) + if utils.is_main_process(): + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + } + save_obj = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'config': config, + 'epoch': epoch, + } + torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) + + with open(os.path.join(args.output_dir, "log.txt"),"a") as f: + f.write(json.dumps(log_stats) + "\n") + + dist.barrier() + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='./configs/pretrain.yaml') + parser.add_argument('--output_dir', default='output/Pretrain') + parser.add_argument('--checkpoint', default='') + parser.add_argument('--evaluate', action='store_true') + parser.add_argument('--device', default='cuda') + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, type=bool) + args = parser.parse_args() + + config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) + + main(args, config) \ No newline at end of file diff --git a/requirements 2.txt b/requirements 2.txt new file mode 100644 index 0000000000000000000000000000000000000000..6ec7753a301409ee15eb5f25416799e4b652ff2d --- /dev/null +++ b/requirements 2.txt @@ -0,0 +1,2 @@ +tensorflow +opencv-python-headless diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f52c0584abc2e80da29a6bc327f4892b5c4011ae --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +timm==0.4.12 +transformers==4.15.0 +fairscale==0.4.4 +pycocoevalcap +torch +torchvision +Pillow \ No newline at end of file diff --git a/run_code.py b/run_code.py new file mode 100644 index 0000000000000000000000000000000000000000..57da809cafad370be0c6ade23ab5ed875b1a6972 --- /dev/null +++ b/run_code.py @@ -0,0 +1,161 @@ + +from tensorflow.keras.layers import Input + +from src.yolo3.model import * +from src.yolo3.detect import * + +from src.utils.image import * +from src.utils.datagen import * +from src.utils.fixes import * + +fix_tf_gpu() +def prepare_model(approach): + ''' + Prepare the YOLO model + ''' + global input_shape, class_names, anchor_boxes, num_classes, num_anchors, model + + # shape (height, width) of the imput image + input_shape = (416, 416) + + # class names + if approach == 1: + class_names = ['H', 'V', 'W'] + + elif approach == 2: + class_names = ['W','WH','WV','WHV'] + + elif approach == 3: + class_names = ['W'] + + else: + raise NotImplementedError('Approach should be 1, 2, or 3') + + # anchor boxes + if approach == 1: + anchor_boxes = np.array( + [ + np.array([[ 76, 59], [ 84, 136], [188, 225]]) /32, # output-1 anchor boxes + np.array([[ 25, 15], [ 46, 29], [ 27, 56]]) /16, # output-2 anchor boxes + np.array([[ 5, 3], [ 10, 8], [ 12, 26]]) /8 # output-3 anchor boxes + ], + dtype='float64' + ) + else: + anchor_boxes = np.array( + [ + np.array([[ 73, 158], [128, 209], [224, 246]]) /32, # output-1 anchor boxes + np.array([[ 32, 50], [ 40, 104], [ 76, 73]]) /16, # output-2 anchor boxes + np.array([[ 6, 11], [ 11, 23], [ 19, 36]]) /8 # output-3 anchor boxes + ], + dtype='float64' + ) + + # number of classes and number of anchors + num_classes = len(class_names) + num_anchors = anchor_boxes.shape[0] * anchor_boxes.shape[1] + + # input and output + input_tensor = Input( shape=(input_shape[0], input_shape[1], 3) ) # input + num_out_filters = ( num_anchors//3 ) * ( 5 + num_classes ) # output + + # build the model + model = yolo_body(input_tensor, num_out_filters) + + # load weights + weight_path = f'model-data/weights/pictor-ppe-v302-a{approach}-yolo-v3-weights.h5' + model.load_weights( weight_path ) + + +def get_detection(img): + # save a copy of the img + act_img = img.copy() + + # shape of the image + ih, iw = act_img.shape[:2] + + # preprocess the image + img = letterbox_image(img, input_shape) + img = np.expand_dims(img, 0) + image_data = np.array(img) / 255. + + # raw prediction from yolo model + prediction = model.predict(image_data) + + # process the raw prediction to get the bounding boxes + boxes = detection( + prediction, + anchor_boxes, + num_classes, + image_shape=(ih, iw), + input_shape=(416, 416), + max_boxes=10, + score_threshold=0.3, + iou_threshold=0.45, + classes_can_overlap=False) + + # convert tensor to numpy + boxes = boxes[0].numpy() + + # draw the detection on the actual image + return (draw_detection(act_img, boxes, class_names), boxes) + + + +def run (image_in, approach): + prepare_model(approach=approach) + +# input_shape = (416, 416) + img = letterbox_image(image_in, input_shape) + + # get the detection on the image + img, all_classes = get_detection(img) + + #print (all_classes) + WHV = 0 + WV = 0 + WH = 0 + W = 0 + H = 0 + V = 0 + for i in all_classes: + if class_names[int(i[-1])] == "WHV": + WHV += 1 + W += 1 + elif class_names[int(i[-1])] == "WH": + WH += 1 + W += 1 + elif class_names[int(i[-1])] == "H": + H += 1 + elif class_names[int(i[-1])] == "V": + V += 1 + elif class_names[int(i[-1])] == "WV": + WV += 1 + W += 1 + elif class_names[int(i[-1])] == "W": + W += 1 + + #Outputs to display the number of each classes in an interpretable format + texts = "" + texts = texts + "Total workers: " + str(W) + "\n" + if approach != 3: + if approach == 1: + texts = texts + "Number of Helmets: " + str(H) + "\n" + texts = texts + "Number of Vests: " + str(V) + "\n" + + elif approach == 2: + + texts = texts + "Workers wearing helmet and vest: " + str(WHV) + "\n" + texts = texts + "Workers wearing only vest: " + str(WV) + "\n" + texts = texts + "Workers wearing only helmet: " + str(WH) + "\n" + + if (W > WHV) and (WHV != 0): + texts = texts + "Workers not wearing helmet and vest: " + str(W - WHV) + "\n" + + if (W > WH) and (WH != 0): + texts = texts + "Workers not wearing helmet and vest: " + str(W - WH) + "\n" + + if (W > WV) and (WV != 0): + texts = texts + "Workers not wearing helmet and vest: " + str(W - WV) + "\n" + + return {'img': img[:, :, ::-1], 'text': texts} diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..4039553bd8de38c3108dbfc6b5dbd11875e42157 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/__pycache__/__init__.cpython-37.pyc b/src/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6e4fd59821da7a6ba0e40a5d60f0f629909b7eb Binary files /dev/null and b/src/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/__pycache__/__init__.cpython-38.pyc b/src/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c2af633a206e0eb91db7cbf5359601bd8010975 Binary files /dev/null and b/src/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/utils/.DS_Store b/src/utils/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b94bb04f9e8c39a690c7cb2d5a11b7bd81303cf0 Binary files /dev/null and b/src/utils/.DS_Store differ diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/utils/__pycache__/__init__.cpython-37.pyc b/src/utils/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3564586737aa2293b0348c59f5ebc18ea4c3a2c6 Binary files /dev/null and b/src/utils/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/utils/__pycache__/__init__.cpython-38.pyc b/src/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9584ea1a5105129789d2daac34d69811b5bc8b0 Binary files /dev/null and b/src/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/utils/__pycache__/datagen.cpython-37.pyc b/src/utils/__pycache__/datagen.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0e025b16e99975f85e2ab44eb4db2e88cf698a7 Binary files /dev/null and b/src/utils/__pycache__/datagen.cpython-37.pyc differ diff --git a/src/utils/__pycache__/datagen.cpython-38.pyc b/src/utils/__pycache__/datagen.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..471fb2e238f6cccee542a15500a69adecb047c35 Binary files /dev/null and b/src/utils/__pycache__/datagen.cpython-38.pyc differ diff --git a/src/utils/__pycache__/fixes.cpython-37.pyc b/src/utils/__pycache__/fixes.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acfee22e035bcb840357381a59de4ccda0721498 Binary files /dev/null and b/src/utils/__pycache__/fixes.cpython-37.pyc differ diff --git a/src/utils/__pycache__/fixes.cpython-38.pyc b/src/utils/__pycache__/fixes.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..391147ff96453a781c04a26b7414fc949d6d8248 Binary files /dev/null and b/src/utils/__pycache__/fixes.cpython-38.pyc differ diff --git a/src/utils/__pycache__/image.cpython-37.pyc b/src/utils/__pycache__/image.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a539b52f2e5864593a548a85532de2a6b82a45a Binary files /dev/null and b/src/utils/__pycache__/image.cpython-37.pyc differ diff --git a/src/utils/__pycache__/image.cpython-38.pyc b/src/utils/__pycache__/image.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81832323d2276c33f9d93e2a1244d9edf4d5871c Binary files /dev/null and b/src/utils/__pycache__/image.cpython-38.pyc differ diff --git a/src/utils/datagen.py b/src/utils/datagen.py new file mode 100644 index 0000000000000000000000000000000000000000..4f48c07c72dde7959aea7341ebd92d97f39ca018 --- /dev/null +++ b/src/utils/datagen.py @@ -0,0 +1,159 @@ +import cv2 +import numpy as np + +from matplotlib.colors import rgb_to_hsv, hsv_to_rgb + + +def rand(a=0, b=1): + return np.random.rand()*(b-a) + a + + +def get_random_data ( + annotation_line, + input_shape, + max_boxes=25, + scale=.3, + hue=.1, + sat=1.5, + val=1.5, + random=True + ): + + ''' + random preprocessing for real-time data augmentation + ''' + + line = annotation_line.split('\t') + h, w = input_shape + box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]]) + + image = cv2.imread(line[0]) + ih, iw, ic = image.shape + + if not random: + resize_scale = min(h/ih, w/iw) + + nw = int(iw * resize_scale) + nh = int(ih * resize_scale) + + max_offx = w - nw + max_offy = h - nh + + dx = max_offx//2 + dy = max_offy//2 + + to_x0, to_y0 = max(0, dx), max(0, dy) + from_x0, from_y0 = max(0, -dx), max(0, -dy) + wx, hy = min(w, dx+nw) - to_x0, min(h, dy+nh) - to_y0 + + # place image + image_data = np.zeros((*input_shape,ic), dtype='uint8') + 128 + image_data[to_y0:to_y0+hy, to_x0:to_x0+wx, :] = cv2.resize(image, (nw, nh))[from_y0:from_y0+hy, from_x0:from_x0+wx, :] + + flip = False + image_data = image_data/255. + else: + if np.random.uniform() >= 0.5: + # scale Up + resize_scale = 1. + scale * np.random.uniform() + resize_scale = max( h*resize_scale/ih, w*resize_scale/iw) + + nw = int(iw * resize_scale) + nh = int(ih * resize_scale) + + max_offx = nw - w + max_offy = nh - h + + dx = int(np.random.uniform() * max_offx) + dy = int(np.random.uniform() * max_offy) + + # resize and crop + image = cv2.resize(image, (nw, nh)) + image_data = image[dy : (dy + h), dx : (dx + w), :] + + dx, dy = (-dx, -dy) + else: + # scale down + mul = 1 if np.random.uniform() >= 0.5 else -1 + + resize_scale = 1. + mul * scale * np.random.uniform() + resize_scale = min( h*resize_scale/ih, w*resize_scale/iw) + + nw = int(iw * resize_scale) + nh = int(ih * resize_scale) + + max_offx = w - nw + max_offy = h - nh + + dx = int(np.random.uniform() * max_offx) + dy = int(np.random.uniform() * max_offy) + + to_x0, to_y0 = max(0, dx), max(0, dy) + from_x0, from_y0 = max(0, -dx), max(0, -dy) + wx, hy = min(w, dx+nw) - to_x0, min(h, dy+nh) - to_y0 + + # place image + image_data = np.zeros((*input_shape,ic), dtype='uint8') + 128 + image_data[to_y0:to_y0+hy, to_x0:to_x0+wx, :] = cv2.resize(image, (nw, nh))[from_y0:from_y0+hy, from_x0:from_x0+wx, :] + + flip = np.random.uniform() >= 0.5 + if flip: image_data = image_data[:,::-1,:] + + # distort color of the image + hue = rand(-hue, hue) + sat = rand(1, sat) if rand()<.5 else 1/rand(1, sat) + val = rand(1, val) if rand()<.5 else 1/rand(1, val) + x = rgb_to_hsv(np.array(image_data)/255.) + x[..., 0] += hue + x[..., 0][x[..., 0]>1] -= 1 + x[..., 0][x[..., 0]<0] += 1 + x[..., 1] *= sat + x[..., 2] *= val + x[x>1] = 1 + x[x<0] = 0 + image_data = hsv_to_rgb(x) # numpy array, 0 to 1 + + # correct boxes + box_data = np.zeros((max_boxes,5)) + if len(box)>0: + np.random.shuffle(box) + box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx + box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy + if flip: box[:, [0,2]] = w - box[:, [2,0]] + box[:, 0:2][box[:, 0:2]<0] = 0 + box[:, 2][box[:, 2]>w] = w + box[:, 3][box[:, 3]>h] = h + box_w = box[:, 2] - box[:, 0] + box_h = box[:, 3] - box[:, 1] + box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box + if len(box)>max_boxes: box = box[:max_boxes] + box_data[:len(box)] = box + + return image_data, box_data + + +def data_generator(annotation_lines, batch_size, input_shape, random): + ''' + data generator for fit_generator + ''' + n = len(annotation_lines) + i = 0 + while True: + image_data = [] + box_data = [] + for _ in range(batch_size): + image, box = get_random_data(annotation_lines[i], input_shape, max_boxes=50, random=random) + image_data.append(image) + box = box[np.sum(box, axis=1) != 0, :] + box_data.append(box) + i = (i+1) % n + image_data = np.array(image_data) + box_data = np.array(box_data) + + yield image_data, box_data + + +def data_generator_wrapper(annotation_lines, batch_size, input_shape, random): + n = len(annotation_lines) + if n==0 or batch_size<=0: return None + return data_generator(annotation_lines, batch_size, input_shape, random) \ No newline at end of file diff --git a/src/utils/fixes.py b/src/utils/fixes.py new file mode 100644 index 0000000000000000000000000000000000000000..89ab964cd4ca4ba6f447a5d7878cddb9b6b2131f --- /dev/null +++ b/src/utils/fixes.py @@ -0,0 +1,19 @@ +import tensorflow as tf + + +def fix_tf_gpu(): + ''' + Fix for the following error message: + UnknownError: Failed to get convolution algorithm. + This is probably because cuDNN failed to initialize... + + More: + https://www.tensorflow.org/api_docs/python/tf/config/experimental/set_memory_growth + ''' + + physical_devices = tf.config.experimental.list_physical_devices('GPU') + + try: + tf.config.experimental.set_memory_growth(physical_devices[0], True) + except: + pass \ No newline at end of file diff --git a/src/utils/image.py b/src/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..36e654f388239956f4ebc780381f9e4d1ffe1c69 --- /dev/null +++ b/src/utils/image.py @@ -0,0 +1,82 @@ +import cv2 +import numpy as np +import matplotlib as mpl + + +def letterbox_image(image, size): + ''' + Resize image with unchanged aspect ratio using padding + ''' + + # original image size + ih, iw, ic = image.shape + + # given size + h, w = size + + # scale and new size of the image + scale = min(w/iw, h/ih) + nw = int(iw*scale) + nh = int(ih*scale) + + # placeholder letter box + new_image = np.zeros((h, w, ic), dtype='uint8') + 128 + + # top-left corner + top, left = (h - nh)//2, (w - nw)//2 + + # paste the scaled image in the placeholder anchoring at the top-left corner + new_image[top:top+nh, left:left+nw, :] = cv2.resize(image, (nw, nh)) + + return new_image + + +def draw_detection( + img, + boxes, + class_names, + # drawing configs + font=cv2.FONT_HERSHEY_DUPLEX, + font_scale=0.5, + box_thickness=2, + border=5, + text_color=(255, 255, 255), + text_weight=1 +): + ''' + Draw the bounding boxes on the image + ''' + # generate some colors for different classes + num_classes = len(class_names) # number of classes + colors = [mpl.colors.hsv_to_rgb((i/num_classes, 1, 1)) * 255 for i in range(num_classes)] + + # draw the detections + for box in boxes: + x1, y1, x2, y2 = box[:4].astype(int) + score = box[-2] + label = int(box[-1]) + + clr = colors[label] + + # draw the bounding box + img = cv2.rectangle(img, (x1, y1), (x2, y2), clr, box_thickness) + + # text: (%) + text = f'{class_names[label]} ({score*100:.0f}%)' + + # get width (tw) and height (th) of the text + (tw, th), _ = cv2.getTextSize(text, font, font_scale, 1) + + # background rectangle for the text + tb_x1 = x1 - box_thickness//2 + tb_y1 = y1 - box_thickness//2 - th - 2*border + tb_x2 = x1 + tw + 2*border + tb_y2 = y1 + + # draw the background rectangle + img = cv2.rectangle(img, (tb_x1, tb_y1), (tb_x2, tb_y2), clr, -1) + + # put the text + img = cv2.putText(img, text, (x1 + border, y1 - border), font, font_scale, text_color, text_weight, cv2.LINE_AA) + + return img \ No newline at end of file diff --git a/src/yolo3/.DS_Store b/src/yolo3/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..255f9536f02b9749148f119e6c8b71356c33ed7c Binary files /dev/null and b/src/yolo3/.DS_Store differ diff --git a/src/yolo3/__init__.py b/src/yolo3/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/yolo3/__pycache__/__init__.cpython-37.pyc b/src/yolo3/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..830b27861f5193e3937bf8d221a6a58bc33bda7d Binary files /dev/null and b/src/yolo3/__pycache__/__init__.cpython-37.pyc differ diff --git a/src/yolo3/__pycache__/__init__.cpython-38.pyc b/src/yolo3/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1893d765cd2cc9e8aebb4a04866abbe968f1dab8 Binary files /dev/null and b/src/yolo3/__pycache__/__init__.cpython-38.pyc differ diff --git a/src/yolo3/__pycache__/detect.cpython-37.pyc b/src/yolo3/__pycache__/detect.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbd968ee727a4475449d112a314c188546d05190 Binary files /dev/null and b/src/yolo3/__pycache__/detect.cpython-37.pyc differ diff --git a/src/yolo3/__pycache__/detect.cpython-38.pyc b/src/yolo3/__pycache__/detect.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33f0e658f3a0e6009e77110ed17e89fd02fed1ec Binary files /dev/null and b/src/yolo3/__pycache__/detect.cpython-38.pyc differ diff --git a/src/yolo3/__pycache__/model.cpython-37.pyc b/src/yolo3/__pycache__/model.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..899c5c7e569efc0247278b5861f932d74ff2140e Binary files /dev/null and b/src/yolo3/__pycache__/model.cpython-37.pyc differ diff --git a/src/yolo3/__pycache__/model.cpython-38.pyc b/src/yolo3/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09965d08c232485ff642e8cceefed01edb971f35 Binary files /dev/null and b/src/yolo3/__pycache__/model.cpython-38.pyc differ diff --git a/src/yolo3/detect.py b/src/yolo3/detect.py new file mode 100644 index 0000000000000000000000000000000000000000..656b97d65600d47518893101c461755ea5d57b17 --- /dev/null +++ b/src/yolo3/detect.py @@ -0,0 +1,170 @@ +import numpy as np +import tensorflow as tf + + +def detection( + prediction, + anchor_boxes, + num_classes, + image_shape, + input_shape, + max_boxes = 20, + score_threshold=0.3, + iou_threshold=0.45, + classes_can_overlap=True, +): + ''' + INPUT: + OUTPUT: + ''' + + all_boxes = [] + + '''@ Each output layer''' + for output, anchors in zip( prediction, anchor_boxes ): + + '''Preprocessing''' + '''-------------''' + # shapes + batch_size = output.shape[0] + grid_h, grid_w = output.shape[1:3] + + # reshape to [batch_size, grid_height, grid_width, num_anchors, box_params] + output = tf.reshape( output, [ -1, grid_h, grid_w, len(anchors), num_classes+5 ] ) + + # create a tensor for the anchor boxes + anchors_tensor = tf.constant(anchors, dtype=output.dtype) + + '''Scaling factors''' + '''---------------''' + image_shape_tensor = tf.cast( image_shape, output.dtype ) # actual image's shape + grids_shape_tensor = tf.cast( output.shape[1:3], output.dtype ) # grid_height, grid_width @ output layer + input_shape_tensor = tf.cast( input_shape, output.dtype ) # yolo input image's shape + + # reshape + image_shape_tensor = tf.reshape( image_shape_tensor, [-1, 1, 1, 1, 2] ) + grids_shape_tensor = tf.reshape( grids_shape_tensor, [-1, 1, 1, 1, 2] ) + input_shape_tensor = tf.reshape( input_shape_tensor, [-1, 1, 1, 1, 2] ) + + ### Scaling factors + sized_shape_tensor = tf.round( image_shape_tensor * tf.reshape( tf.reduce_min( input_shape_tensor / image_shape_tensor, axis=-1 ), [-1,1,1,1,1] ) ) + # to scale the boxes from grid's unit to actual image's pixel unit + box_scaling = input_shape_tensor * image_shape_tensor / sized_shape_tensor / grids_shape_tensor + # to offset the boxes + box_offsets = (tf.expand_dims(tf.reduce_max(image_shape_tensor, axis=-1), axis=-1) - image_shape_tensor) / 2. + + '''Box geometric properties''' + '''------------------------''' + grid_h, grid_w = output.shape[1:3] # grid_height, grid_width @ output layer + + grid_i = tf.reshape( np.arange(grid_h), [-1, 1, 1, 1] ) + grid_i = tf.tile( grid_i, [1, grid_w, 1, 1] ) + + grid_j = tf.reshape( np.arange(grid_w), [1, -1, 1, 1] ) + grid_j = tf.tile( grid_j, [grid_h, 1, 1, 1] ) + + grid_ji = tf.concat( [grid_j, grid_i], axis=-1 ) + grid_ji = tf.cast( grid_ji, output.dtype ) + + # Box centers + box_xy = output[..., 0:2] + box_xy = tf.sigmoid( box_xy ) + grid_ji + + # Box sizes + box_wh = output[..., 2:4] + box_wh = tf.exp( box_wh ) * anchors_tensor + + # scale to actual pixel unit + box_xy = box_xy * box_scaling - box_offsets[...,::-1] + box_wh = box_wh * box_scaling + + # calculate top-left corner (x1, y1) and bottom-right corner (x2, y2) of the boxex + box_x1_y1 = box_xy - box_wh / 2 + box_x2_y2 = box_xy + box_wh / 2 + + # top-left corner cannot be negative + box_x1_y1 = tf.maximum(0, box_x1_y1) + # bottom-right corner cannot be more than actual image size + box_x2_y2 = tf.minimum(box_x2_y2, image_shape_tensor[..., ::-1]) + + '''Box labels and confidences''' + '''--------------------------''' + # class probabilities = objectness score * conditional class probabilities + if classes_can_overlap: + # use sigmoid for the conditional class probabilities + classs_probs = tf.sigmoid( output[..., 4:5] ) * tf.sigmoid( output[..., 5:] ) + else: + # use softmax for the conditional class probabilities + classs_probs = tf.sigmoid( output[..., 4:5] ) * tf.nn.softmax( output[..., 5:] ) + + box_cl = tf.argmax( classs_probs, axis=-1 ) # final classes + box_sc = tf.reduce_max( classs_probs, axis=-1 ) # confidence scores + + '''Organize''' + '''--------''' + # take care of dtype and dimensions + box_cl = tf.cast( box_cl, output.dtype ) + box_cl = tf.expand_dims(box_cl, axis=-1) + box_sc = tf.expand_dims(box_sc, axis=-1) + + # store all information as: [ left(x1), top(y1), right(x2), bottom(y2), confidence, label ] + boxes = tf.reshape( tf.concat( [ box_x1_y1, box_x2_y2, box_sc, box_cl ], axis=-1 ), + [batch_size, -1, 6] ) + + all_boxes. append( boxes ) + + # Merge across all output layers + all_boxes = tf.concat( all_boxes, axis=1 ) + + # To store all the final results of all images in the batch + all_final_boxes = [] + + '''For each image in the batch''' + for _boxes_ in all_boxes: + + if classes_can_overlap: + '''Perform NMS for each class individually''' + + # to stote the final results of this image + final_boxes = [] + + for class_id in range(num_classes): + + # Get the boxes and scores for this class + class_boxes = _boxes_[ _boxes_[...,-1] == class_id ] + + '''Non-max-suppression''' + selected_idc = tf.image.non_max_suppression( + class_boxes[...,:4], # boxes' (y1,x1,y2,x2) + class_boxes[...,-2], # boxes' scores + max_output_size = max_boxes, + iou_threshold = iou_threshold, + score_threshold = score_threshold + ) + + # boxes selected by nms + class_boxes = tf.gather( class_boxes, selected_idc ) + final_boxes.append( class_boxes ) + + # concatenate boxes for each class in the image + final_boxes = tf.concat( final_boxes, axis=0 ) + + else: + '''Perform NMS for all classes''' + + # nms indices + selected_idc = tf.image.non_max_suppression( + _boxes_[...,:4], # boxes' (y1,x1,y2,x2) + _boxes_[...,-2], # boxes' scores + max_output_size = max_boxes, + iou_threshold = iou_threshold, + score_threshold = score_threshold + ) + + # boxes selected by nms + final_boxes = tf.gather( _boxes_, selected_idc ) + + # append final boxes for each image in the batch + all_final_boxes.append( final_boxes ) + + return all_final_boxes \ No newline at end of file diff --git a/src/yolo3/model.py b/src/yolo3/model.py new file mode 100644 index 0000000000000000000000000000000000000000..0df1d54d28dc10ddac4d58b1bec11a42077fd8b2 --- /dev/null +++ b/src/yolo3/model.py @@ -0,0 +1,95 @@ +import tensorflow as tf +import tensorflow.keras.backend as K + +from tensorflow.keras.layers import Input, Conv2D, Add, ZeroPadding2D, UpSampling2D, Concatenate, MaxPooling2D +from tensorflow.keras.layers import LeakyReLU, BatchNormalization +from tensorflow.keras.models import Sequential, Model +from tensorflow.keras.regularizers import l2 + +'''====================================================================================================''' +'''BLOCKS''' + +'''Convolutional Block''' +def yolo_ConvBlock (input_tensor, num_filters, filter_size, strides = (1,1) ): + padding = 'valid' if strides == (2,2) else 'same' + + ### Layers + x = Conv2D( num_filters, filter_size, strides, padding, use_bias=False, kernel_regularizer=l2(5e-4) ) (input_tensor) + x = BatchNormalization() (x) + x = LeakyReLU(alpha=0.1) (x) + + return x + +'''Residual Block''' +def yolo_ResidualBlocks (input_tensor, num_filters, num_blocks ): + + ### Layers + x = ZeroPadding2D( ((1,0),(1,0)) ) (input_tensor) # left & top padding + x = yolo_ConvBlock ( x, num_filters, filter_size=(3,3), strides = (2,2) ) + + for _ in range( num_blocks ): + y = yolo_ConvBlock ( x, num_filters//2, filter_size=(1,1), strides = (1,1) ) + y = yolo_ConvBlock ( y, num_filters , filter_size=(3,3), strides = (1,1) ) + x = Add() ([x, y]) + + return x + +'''Output Block''' +def yolo_OutputBlock (x, num_filters, out_filters ): + + ### Layers + x = yolo_ConvBlock ( x, 1*num_filters, filter_size=(1,1), strides = (1,1) ) + x = yolo_ConvBlock ( x, 2*num_filters, filter_size=(3,3), strides = (1,1) ) + x = yolo_ConvBlock ( x, 1*num_filters, filter_size=(1,1), strides = (1,1) ) + x = yolo_ConvBlock ( x, 2*num_filters, filter_size=(3,3), strides = (1,1) ) + x = yolo_ConvBlock ( x, 1*num_filters, filter_size=(1,1), strides = (1,1) ) + + y = yolo_ConvBlock ( x, 2*num_filters, filter_size=(3,3), strides = (1,1) ) + y = Conv2D ( filters=out_filters, kernel_size=(1,1), strides=(1,1), + padding='same', use_bias=True, kernel_regularizer=l2(5e-4) ) (y) + + return x, y + +'''====================================================================================================''' +'''COMPLETE MODEL''' + +def yolo_body (input_tensor, num_out_filters): + ''' + Input: + input_tensor = Input( shape=( *input_shape, 3 ) ) + num_out_filter = ( num_anchors // 3 ) * ( 5 + num_classes ) + Output: + complete YOLO-v3 model + ''' + + # 1st Conv block + x = yolo_ConvBlock( input_tensor, num_filters=32, filter_size=(3,3), strides=(1,1) ) + + # 5 Resblocks + x = yolo_ResidualBlocks ( x, num_filters= 64, num_blocks=1 ) + x = yolo_ResidualBlocks ( x, num_filters= 128, num_blocks=2 ) + x = yolo_ResidualBlocks ( x, num_filters= 256, num_blocks=8 ) + x = yolo_ResidualBlocks ( x, num_filters= 512, num_blocks=8 ) + x = yolo_ResidualBlocks ( x, num_filters=1024, num_blocks=4 ) + + darknet = Model( input_tensor, x ) # will use it just in a moment + + # 1st output block + x, y1 = yolo_OutputBlock( x, num_filters= 512, out_filters=num_out_filters ) + + # 2nd output block + x = yolo_ConvBlock( x, num_filters=256, filter_size=(1,1), strides=(1,1) ) + x = UpSampling2D(2) (x) + x = Concatenate() ( [x, darknet.layers[152].output] ) + x, y2 = yolo_OutputBlock( x, num_filters= 256, out_filters=num_out_filters ) + + # 3rd output block + x = yolo_ConvBlock( x, num_filters=128, filter_size=(1,1), strides=(1,1) ) + x = UpSampling2D(2) (x) + x = Concatenate() ( [x, darknet.layers[92].output] ) + x, y3 = yolo_OutputBlock( x, num_filters= 128, out_filters=num_out_filters ) + + # Final model + model = Model( input_tensor, [y1, y2, y3] ) + + return model diff --git a/train_caption.py b/train_caption.py new file mode 100644 index 0000000000000000000000000000000000000000..7c639ac646b9a1b8074b6e9c2343b961de76db05 --- /dev/null +++ b/train_caption.py @@ -0,0 +1,206 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import argparse +import os +import ruamel_yaml as yaml +import numpy as np +import random +import time +import datetime +import json +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +import torch.distributed as dist +from torch.utils.data import DataLoader + +from models.blip import blip_decoder +import utils +from utils import cosine_lr_schedule +from data import create_dataset, create_sampler, create_loader +from data.utils import save_result, coco_caption_eval + +def train(model, data_loader, optimizer, epoch, device): + # train + model.train() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) + header = 'Train Caption Epoch: [{}]'.format(epoch) + print_freq = 50 + + for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + image = image.to(device) + + loss = model(image, caption) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger.global_avg()) + return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(model, data_loader, device, config): + # evaluate + model.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Caption generation:' + print_freq = 10 + + result = [] + for image, image_id in metric_logger.log_every(data_loader, print_freq, header): + + image = image.to(device) + + captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'], + min_length=config['min_length']) + + for caption, img_id in zip(captions, image_id): + result.append({"image_id": img_id.item(), "caption": caption}) + + return result + + +def main(args, config): + utils.init_distributed_mode(args) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = True + + #### Dataset #### + print("Creating captioning dataset") + train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config) + + if args.distributed: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank) + else: + samplers = [None, None, None] + + train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers, + batch_size=[config['batch_size']]*3,num_workers=[4,4,4], + is_trains=[True, False, False], collate_fns=[None,None,None]) + + #### Model #### + print("Creating model") + model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], + vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], + prompt=config['prompt']) + + model = model.to(device) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) + + best = 0 + best_epoch = 0 + + print("Start training") + start_time = time.time() + for epoch in range(0, config['max_epoch']): + if not args.evaluate: + if args.distributed: + train_loader.sampler.set_epoch(epoch) + + cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) + + train_stats = train(model, train_loader, optimizer, epoch, device) + + val_result = evaluate(model_without_ddp, val_loader, device, config) + val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id') + + test_result = evaluate(model_without_ddp, test_loader, device, config) + test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id') + + if utils.is_main_process(): + coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val') + coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test') + + if args.evaluate: + log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()}, + **{f'test_{k}': v for k, v in coco_test.eval.items()}, + } + with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f: + f.write(json.dumps(log_stats) + "\n") + else: + save_obj = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'config': config, + 'epoch': epoch, + } + + if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best: + best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] + best_epoch = epoch + torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) + + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'val_{k}': v for k, v in coco_val.eval.items()}, + **{f'test_{k}': v for k, v in coco_test.eval.items()}, + 'epoch': epoch, + 'best_epoch': best_epoch, + } + with open(os.path.join(args.output_dir, "log.txt"),"a") as f: + f.write(json.dumps(log_stats) + "\n") + + if args.evaluate: + break + dist.barrier() + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='./configs/caption_coco.yaml') + parser.add_argument('--output_dir', default='output/Caption_coco') + parser.add_argument('--evaluate', action='store_true') + parser.add_argument('--device', default='cuda') + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, type=bool) + args = parser.parse_args() + + config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) + + args.result_dir = os.path.join(args.output_dir, 'result') + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + Path(args.result_dir).mkdir(parents=True, exist_ok=True) + + yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) + + main(args, config) \ No newline at end of file diff --git a/train_nlvr.py b/train_nlvr.py new file mode 100644 index 0000000000000000000000000000000000000000..84b247bda2334c1fd894b6c11d33ef48c8e7df28 --- /dev/null +++ b/train_nlvr.py @@ -0,0 +1,213 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import argparse +import os +import ruamel_yaml as yaml +import numpy as np +import random +import time +import datetime +import json +from pathlib import Path +import json +import pickle + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +import torch.backends.cudnn as cudnn +import torch.distributed as dist + +from models.blip_nlvr import blip_nlvr + +import utils +from utils import cosine_lr_schedule, warmup_lr_schedule +from data import create_dataset, create_sampler, create_loader + +def train(model, data_loader, optimizer, epoch, device, config): + # train + model.train() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) + metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) + + header = 'Train Epoch: [{}]'.format(epoch) + print_freq = 50 + step_size = 10 + + for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + + images = torch.cat([image0, image1], dim=0) + images, targets = images.to(device), targets.to(device) + + loss = model(images, text, targets=targets, train=True) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + metric_logger.update(loss=loss.item()) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger.global_avg()) + return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluate(model, data_loader, device, config): + # test + model.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + + header = 'Evaluation:' + print_freq = 50 + + for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header): + images = torch.cat([image0, image1], dim=0) + images, targets = images.to(device), targets.to(device) + + prediction = model(images, text, targets=targets, train=False) + + _, pred_class = prediction.max(1) + accuracy = (targets==pred_class).sum() / targets.size(0) + + metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0)) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + + print("Averaged stats:", metric_logger.global_avg()) + return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} + + + +def main(args, config): + utils.init_distributed_mode(args) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = True + + #### Dataset #### + print("Creating dataset") + datasets = create_dataset('nlvr', config) + + if args.distributed: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank) + else: + samplers = [None, None, None] + + batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']] + train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size, + num_workers=[4,4,4],is_trains=[True,False,False], + collate_fns=[None,None,None]) + + #### Model #### + print("Creating model") + model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'], + vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer']) + + model = model.to(device) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) + + print("Start training") + start_time = time.time() + best = 0 + best_epoch = 0 + + for epoch in range(0, config['max_epoch']): + if not args.evaluate: + if args.distributed: + train_loader.sampler.set_epoch(epoch) + + cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) + + train_stats = train(model, train_loader, optimizer, epoch, device, config) + + val_stats = evaluate(model, val_loader, device, config) + test_stats = evaluate(model, test_loader, device, config) + + if utils.is_main_process(): + if args.evaluate: + log_stats = {**{f'val_{k}': v for k, v in val_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + } + with open(os.path.join(args.output_dir, "log.txt"),"a") as f: + f.write(json.dumps(log_stats) + "\n") + + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'val_{k}': v for k, v in val_stats.items()}, + **{f'test_{k}': v for k, v in test_stats.items()}, + 'epoch': epoch, + } + + if float(val_stats['acc'])>best: + save_obj = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'config': config, + 'epoch': epoch, + } + torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) + best = float(val_stats['acc']) + best_epoch = epoch + + with open(os.path.join(args.output_dir, "log.txt"),"a") as f: + f.write(json.dumps(log_stats) + "\n") + if args.evaluate: + break + + dist.barrier() + + if utils.is_main_process(): + with open(os.path.join(args.output_dir, "log.txt"),"a") as f: + f.write("best epoch: %d"%best_epoch) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='./configs/nlvr.yaml') + parser.add_argument('--output_dir', default='output/NLVR') + parser.add_argument('--evaluate', action='store_true') + parser.add_argument('--device', default='cuda') + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, type=bool) + args = parser.parse_args() + + config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) + + main(args, config) \ No newline at end of file diff --git a/train_retrieval.py b/train_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..574f03382cc8197b97971a11ae54b632bcfe6655 --- /dev/null +++ b/train_retrieval.py @@ -0,0 +1,345 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import argparse +import os +import ruamel_yaml as yaml +import numpy as np +import random +import time +import datetime +import json +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +import torch.distributed as dist +from torch.utils.data import DataLoader + +from models.blip_retrieval import blip_retrieval +import utils +from utils import cosine_lr_schedule +from data import create_dataset, create_sampler, create_loader + + +def train(model, data_loader, optimizer, epoch, device, config): + # train + model.train() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) + metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) + header = 'Train Epoch: [{}]'.format(epoch) + print_freq = 50 + + for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + image = image.to(device,non_blocking=True) + idx = idx.to(device,non_blocking=True) + + if epoch>0: + alpha = config['alpha'] + else: + alpha = config['alpha']*min(1,i/len(data_loader)) + + loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx) + loss = loss_ita + loss_itm + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + metric_logger.update(loss_itm=loss_itm.item()) + metric_logger.update(loss_ita=loss_ita.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger.global_avg()) + return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluation(model, data_loader, device, config): + # test + model.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Evaluation:' + + print('Computing features for evaluation...') + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_ids = [] + text_embeds = [] + text_atts = [] + for i in range(0, num_text, text_bs): + text = texts[i: min(num_text, i+text_bs)] + text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device) + text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text') + text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:])) + text_embeds.append(text_embed) + text_ids.append(text_input.input_ids) + text_atts.append(text_input.attention_mask) + + text_embeds = torch.cat(text_embeds,dim=0) + text_ids = torch.cat(text_ids,dim=0) + text_atts = torch.cat(text_atts,dim=0) + text_ids[:,0] = model.tokenizer.enc_token_id + + image_feats = [] + image_embeds = [] + for image, img_id in data_loader: + image = image.to(device) + image_feat = model.visual_encoder(image) + image_embed = model.vision_proj(image_feat[:,0,:]) + image_embed = F.normalize(image_embed,dim=-1) + + image_feats.append(image_feat.cpu()) + image_embeds.append(image_embed) + + image_feats = torch.cat(image_feats,dim=0) + image_embeds = torch.cat(image_embeds,dim=0) + + sims_matrix = image_embeds @ text_embeds.t() + score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device) + + num_tasks = utils.get_world_size() + rank = utils.get_rank() + step = sims_matrix.size(0)//num_tasks + 1 + start = rank*step + end = min(sims_matrix.size(0),start+step) + + for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): + topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) + + encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device) + encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device) + output = model.text_encoder(text_ids[topk_idx], + attention_mask = text_atts[topk_idx], + encoder_hidden_states = encoder_output, + encoder_attention_mask = encoder_att, + return_dict = True, + ) + score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] + score_matrix_i2t[start+i,topk_idx] = score + topk_sim + + sims_matrix = sims_matrix.t() + score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device) + + step = sims_matrix.size(0)//num_tasks + 1 + start = rank*step + end = min(sims_matrix.size(0),start+step) + + for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)): + + topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0) + encoder_output = image_feats[topk_idx].to(device) + encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device) + output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1), + attention_mask = text_atts[start+i].repeat(config['k_test'],1), + encoder_hidden_states = encoder_output, + encoder_attention_mask = encoder_att, + return_dict = True, + ) + score = model.itm_head(output.last_hidden_state[:,0,:])[:,1] + score_matrix_t2i[start+i,topk_idx] = score + topk_sim + + if args.distributed: + dist.barrier() + torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM) + torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Evaluation time {}'.format(total_time_str)) + + return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() + + + +@torch.no_grad() +def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt): + + #Images->Text + ranks = np.zeros(scores_i2t.shape[0]) + for index,score in enumerate(scores_i2t): + inds = np.argsort(score)[::-1] + # Score + rank = 1e20 + for i in img2txt[index]: + tmp = np.where(inds == i)[0][0] + if tmp < rank: + rank = tmp + ranks[index] = rank + + # Compute metrics + tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) + tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) + tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) + + #Text->Images + ranks = np.zeros(scores_t2i.shape[0]) + + for index,score in enumerate(scores_t2i): + inds = np.argsort(score)[::-1] + ranks[index] = np.where(inds == txt2img[index])[0][0] + + # Compute metrics + ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) + ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) + ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) + + tr_mean = (tr1 + tr5 + tr10) / 3 + ir_mean = (ir1 + ir5 + ir10) / 3 + r_mean = (tr_mean + ir_mean) / 2 + + eval_result = {'txt_r1': tr1, + 'txt_r5': tr5, + 'txt_r10': tr10, + 'txt_r_mean': tr_mean, + 'img_r1': ir1, + 'img_r5': ir5, + 'img_r10': ir10, + 'img_r_mean': ir_mean, + 'r_mean': r_mean} + return eval_result + + +def main(args, config): + utils.init_distributed_mode(args) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = True + + #### Dataset #### + print("Creating retrieval dataset") + train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config) + + if args.distributed: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None] + else: + samplers = [None, None, None] + + train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers, + batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2, + num_workers=[4,4,4], + is_trains=[True, False, False], + collate_fns=[None,None,None]) + + + #### Model #### + print("Creating model") + model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'], + vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'], + queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank']) + + model = model.to(device) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) + + best = 0 + best_epoch = 0 + + print("Start training") + start_time = time.time() + + for epoch in range(0, config['max_epoch']): + if not args.evaluate: + if args.distributed: + train_loader.sampler.set_epoch(epoch) + + cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) + + train_stats = train(model, train_loader, optimizer, epoch, device, config) + + score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, device, config) + score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config) + + if utils.is_main_process(): + + val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt) + print(val_result) + + if val_result['r_mean']>best: + save_obj = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'config': config, + 'epoch': epoch, + } + torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth')) + best = val_result['r_mean'] + best_epoch = epoch + + test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt) + print(test_result) + + if args.evaluate: + log_stats = {**{f'val_{k}': v for k, v in val_result.items()}, + **{f'test_{k}': v for k, v in test_result.items()}, + } + with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f: + f.write(json.dumps(log_stats) + "\n") + else: + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + **{f'val_{k}': v for k, v in val_result.items()}, + **{f'test_{k}': v for k, v in test_result.items()}, + 'epoch': epoch, + 'best_epoch': best_epoch, + } + with open(os.path.join(args.output_dir, "log.txt"),"a") as f: + f.write(json.dumps(log_stats) + "\n") + + if args.evaluate: + break + + dist.barrier() + torch.cuda.empty_cache() + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='./configs/retrieval_flickr.yaml') + parser.add_argument('--output_dir', default='output/Retrieval_flickr') + parser.add_argument('--evaluate', action='store_true') + parser.add_argument('--device', default='cuda') + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, type=bool) + args = parser.parse_args() + + config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) + + main(args, config) \ No newline at end of file diff --git a/train_vqa.py b/train_vqa.py new file mode 100644 index 0000000000000000000000000000000000000000..89eb7490862e517cc660f842396033c21d441a20 --- /dev/null +++ b/train_vqa.py @@ -0,0 +1,202 @@ +''' + * Copyright (c) 2022, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li +''' +import argparse +import os +import ruamel_yaml as yaml +import numpy as np +import random +import time +import datetime +import json +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +import torch.backends.cudnn as cudnn +import torch.distributed as dist + +from models.blip_vqa import blip_vqa +import utils +from utils import cosine_lr_schedule +from data import create_dataset, create_sampler, create_loader +from data.vqa_dataset import vqa_collate_fn +from data.utils import save_result + + +def train(model, data_loader, optimizer, epoch, device): + # train + model.train() + + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}')) + + header = 'Train Epoch: [{}]'.format(epoch) + print_freq = 50 + + for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True) + + loss = model(image, question, answer, train=True, n=n, weights=weights) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger.global_avg()) + return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} + + +@torch.no_grad() +def evaluation(model, data_loader, device, config) : + # test + model.eval() + + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Generate VQA test result:' + print_freq = 50 + + result = [] + + if config['inference']=='rank': + answer_list = data_loader.dataset.answer_list + answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device) + answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id + + for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): + image = image.to(device,non_blocking=True) + + if config['inference']=='generate': + answers = model(image, question, train=False, inference='generate') + + for answer, ques_id in zip(answers, question_id): + ques_id = int(ques_id.item()) + result.append({"question_id":ques_id, "answer":answer}) + + elif config['inference']=='rank': + answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test']) + + for ques_id, answer_id in zip(question_id, answer_ids): + result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]}) + + return result + + +def main(args, config): + utils.init_distributed_mode(args) + + device = torch.device(args.device) + + # fix the seed for reproducibility + seed = args.seed + utils.get_rank() + torch.manual_seed(seed) + np.random.seed(seed) + random.seed(seed) + cudnn.benchmark = True + + #### Dataset #### + print("Creating vqa datasets") + datasets = create_dataset('vqa', config) + + if args.distributed: + num_tasks = utils.get_world_size() + global_rank = utils.get_rank() + samplers = create_sampler(datasets, [True, False], num_tasks, global_rank) + else: + samplers = [None, None] + + train_loader, test_loader = create_loader(datasets,samplers, + batch_size=[config['batch_size_train'],config['batch_size_test']], + num_workers=[4,4],is_trains=[True, False], + collate_fns=[vqa_collate_fn,None]) + #### Model #### + print("Creating model") + model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'], + vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer']) + + model = model.to(device) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay']) + + best = 0 + best_epoch = 0 + + print("Start training") + start_time = time.time() + for epoch in range(0, config['max_epoch']): + if not args.evaluate: + if args.distributed: + train_loader.sampler.set_epoch(epoch) + + cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) + + train_stats = train(model, train_loader, optimizer, epoch, device) + + else: + break + + if utils.is_main_process(): + log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, + 'epoch': epoch, + } + with open(os.path.join(args.output_dir, "log.txt"),"a") as f: + f.write(json.dumps(log_stats) + "\n") + + save_obj = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'config': config, + 'epoch': epoch, + } + torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) + + dist.barrier() + + vqa_result = evaluation(model_without_ddp, test_loader, device, config) + result_file = save_result(vqa_result, args.result_dir, 'vqa_result') + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='./configs/vqa.yaml') + parser.add_argument('--output_dir', default='output/VQA') + parser.add_argument('--evaluate', action='store_true') + parser.add_argument('--device', default='cuda') + parser.add_argument('--seed', default=42, type=int) + parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') + parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') + parser.add_argument('--distributed', default=True, type=bool) + args = parser.parse_args() + + config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) + + args.result_dir = os.path.join(args.output_dir, 'result') + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + Path(args.result_dir).mkdir(parents=True, exist_ok=True) + + yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) + + main(args, config) \ No newline at end of file diff --git a/transform/randaugment.py b/transform/randaugment.py new file mode 100644 index 0000000000000000000000000000000000000000..094d9f4cacc93146d2bab7311d9dc04feb07032c --- /dev/null +++ b/transform/randaugment.py @@ -0,0 +1,340 @@ +import cv2 +import numpy as np + + +## aug functions +def identity_func(img): + return img + + +def autocontrast_func(img, cutoff=0): + ''' + same output as PIL.ImageOps.autocontrast + ''' + n_bins = 256 + + def tune_channel(ch): + n = ch.size + cut = cutoff * n // 100 + if cut == 0: + high, low = ch.max(), ch.min() + else: + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + low = np.argwhere(np.cumsum(hist) > cut) + low = 0 if low.shape[0] == 0 else low[0] + high = np.argwhere(np.cumsum(hist[::-1]) > cut) + high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] + if high <= low: + table = np.arange(n_bins) + else: + scale = (n_bins - 1) / (high - low) + offset = -low * scale + table = np.arange(n_bins) * scale + offset + table[table < 0] = 0 + table[table > n_bins - 1] = n_bins - 1 + table = table.clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def equalize_func(img): + ''' + same output as PIL.ImageOps.equalize + PIL's implementation is different from cv2.equalize + ''' + n_bins = 256 + + def tune_channel(ch): + hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) + non_zero_hist = hist[hist != 0].reshape(-1) + step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) + if step == 0: return ch + n = np.empty_like(hist) + n[0] = step // 2 + n[1:] = hist[:-1] + table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) + return table[ch] + + channels = [tune_channel(ch) for ch in cv2.split(img)] + out = cv2.merge(channels) + return out + + +def rotate_func(img, degree, fill=(0, 0, 0)): + ''' + like PIL, rotate by degree, not radians + ''' + H, W = img.shape[0], img.shape[1] + center = W / 2, H / 2 + M = cv2.getRotationMatrix2D(center, degree, 1) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill) + return out + + +def solarize_func(img, thresh=128): + ''' + same output as PIL.ImageOps.posterize + ''' + table = np.array([el if el < thresh else 255 - el for el in range(256)]) + table = table.clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def color_func(img, factor): + ''' + same output as PIL.ImageEnhance.Color + ''' + ## implementation according to PIL definition, quite slow + # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] + # out = blend(degenerate, img, factor) + # M = ( + # np.eye(3) * factor + # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) + # )[np.newaxis, np.newaxis, :] + M = ( + np.float32([ + [0.886, -0.114, -0.114], + [-0.587, 0.413, -0.587], + [-0.299, -0.299, 0.701]]) * factor + + np.float32([[0.114], [0.587], [0.299]]) + ) + out = np.matmul(img, M).clip(0, 255).astype(np.uint8) + return out + + +def contrast_func(img, factor): + """ + same output as PIL.ImageEnhance.Contrast + """ + mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) + table = np.array([( + el - mean) * factor + mean + for el in range(256) + ]).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def brightness_func(img, factor): + ''' + same output as PIL.ImageEnhance.Contrast + ''' + table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) + out = table[img] + return out + + +def sharpness_func(img, factor): + ''' + The differences the this result and PIL are all on the 4 boundaries, the center + areas are same + ''' + kernel = np.ones((3, 3), dtype=np.float32) + kernel[1][1] = 5 + kernel /= 13 + degenerate = cv2.filter2D(img, -1, kernel) + if factor == 0.0: + out = degenerate + elif factor == 1.0: + out = img + else: + out = img.astype(np.float32) + degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] + out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) + out = out.astype(np.uint8) + return out + + +def shear_x_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, factor, 0], [0, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_x_func(img, offset, fill=(0, 0, 0)): + ''' + same output as PIL.Image.transform + ''' + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, -offset], [0, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def translate_y_func(img, offset, fill=(0, 0, 0)): + ''' + same output as PIL.Image.transform + ''' + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [0, 1, -offset]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def posterize_func(img, bits): + ''' + same output as PIL.ImageOps.posterize + ''' + out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) + return out + + +def shear_y_func(img, factor, fill=(0, 0, 0)): + H, W = img.shape[0], img.shape[1] + M = np.float32([[1, 0, 0], [factor, 1, 0]]) + out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) + return out + + +def cutout_func(img, pad_size, replace=(0, 0, 0)): + replace = np.array(replace, dtype=np.uint8) + H, W = img.shape[0], img.shape[1] + rh, rw = np.random.random(2) + pad_size = pad_size // 2 + ch, cw = int(rh * H), int(rw * W) + x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) + y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) + out = img.copy() + out[x1:x2, y1:y2, :] = replace + return out + + +### level to args +def enhance_level_to_args(MAX_LEVEL): + def level_to_args(level): + return ((level / MAX_LEVEL) * 1.8 + 0.1,) + return level_to_args + + +def shear_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 0.3 + if np.random.random() > 0.5: level = -level + return (level, replace_value) + + return level_to_args + + +def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * float(translate_const) + if np.random.random() > 0.5: level = -level + return (level, replace_value) + + return level_to_args + + +def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): + def level_to_args(level): + level = int((level / MAX_LEVEL) * cutout_const) + return (level, replace_value) + + return level_to_args + + +def solarize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 256) + return (level, ) + return level_to_args + + +def none_level_to_args(level): + return () + + +def posterize_level_to_args(MAX_LEVEL): + def level_to_args(level): + level = int((level / MAX_LEVEL) * 4) + return (level, ) + return level_to_args + + +def rotate_level_to_args(MAX_LEVEL, replace_value): + def level_to_args(level): + level = (level / MAX_LEVEL) * 30 + if np.random.random() < 0.5: + level = -level + return (level, replace_value) + + return level_to_args + + +func_dict = { + 'Identity': identity_func, + 'AutoContrast': autocontrast_func, + 'Equalize': equalize_func, + 'Rotate': rotate_func, + 'Solarize': solarize_func, + 'Color': color_func, + 'Contrast': contrast_func, + 'Brightness': brightness_func, + 'Sharpness': sharpness_func, + 'ShearX': shear_x_func, + 'TranslateX': translate_x_func, + 'TranslateY': translate_y_func, + 'Posterize': posterize_func, + 'ShearY': shear_y_func, +} + +translate_const = 10 +MAX_LEVEL = 10 +replace_value = (128, 128, 128) +arg_dict = { + 'Identity': none_level_to_args, + 'AutoContrast': none_level_to_args, + 'Equalize': none_level_to_args, + 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), + 'Solarize': solarize_level_to_args(MAX_LEVEL), + 'Color': enhance_level_to_args(MAX_LEVEL), + 'Contrast': enhance_level_to_args(MAX_LEVEL), + 'Brightness': enhance_level_to_args(MAX_LEVEL), + 'Sharpness': enhance_level_to_args(MAX_LEVEL), + 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), + 'TranslateX': translate_level_to_args( + translate_const, MAX_LEVEL, replace_value + ), + 'TranslateY': translate_level_to_args( + translate_const, MAX_LEVEL, replace_value + ), + 'Posterize': posterize_level_to_args(MAX_LEVEL), + 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), +} + + +class RandomAugment(object): + + def __init__(self, N=2, M=10, isPIL=False, augs=[]): + self.N = N + self.M = M + self.isPIL = isPIL + if augs: + self.augs = augs + else: + self.augs = list(arg_dict.keys()) + + def get_random_ops(self): + sampled_ops = np.random.choice(self.augs, self.N) + return [(op, 0.5, self.M) for op in sampled_ops] + + def __call__(self, img): + if self.isPIL: + img = np.array(img) + ops = self.get_random_ops() + for name, prob, level in ops: + if np.random.random() > prob: + continue + args = arg_dict[name](level) + img = func_dict[name](img, *args) + return img + + +if __name__ == '__main__': + a = RandomAugment() + img = np.random.randn(32, 32, 3) + a(img) \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ebe0e1dc2f5d200156d5dd1acc305a8b7b7b98da --- /dev/null +++ b/utils.py @@ -0,0 +1,278 @@ +import math +def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): + """Decay the learning rate""" + lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): + """Warmup the learning rate""" + lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): + """Decay the learning rate""" + lr = max(min_lr, init_lr * (decay_rate**epoch)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +import numpy as np +import io +import os +import time +from collections import defaultdict, deque +import datetime + +import torch +import torch.distributed as dist + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def global_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {:.4f}".format(name, meter.global_avg) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def compute_acc(logits, label, reduction='mean'): + ret = (torch.argmax(logits, dim=1) == label).float() + if reduction == 'none': + return ret.detach() + elif reduction == 'mean': + return ret.mean().item() + +def compute_n_params(model, return_str=True): + tot = 0 + for p in model.parameters(): + w = 1 + for x in p.shape: + w *= x + tot += w + if return_str: + if tot >= 1e6: + return '{:.1f}M'.format(tot / 1e6) + else: + return '{:.1f}K'.format(tot / 1e3) + else: + return tot + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}, word {}): {}'.format( + args.rank, args.world_size, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + \ No newline at end of file