zamborg commited on
Commit
5281471
1 Parent(s): 49c0315

updated things

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +52 -91
  3. virtex/requirements.txt +0 -18
.gitignore CHANGED
@@ -2,3 +2,4 @@
2
  *.pth
3
  *.yaml
4
  *ipynb_checkpoints
 
 
2
  *.pth
3
  *.yaml
4
  *ipynb_checkpoints
5
+ __pycache__
app.py CHANGED
@@ -1,98 +1,59 @@
1
  import streamlit as st
2
- from huggingface_hub import snapshot_download
3
- from PIL import Image
4
-
5
- import argparse
6
- import json
7
- import os
8
- from typing import Any, Dict, List
9
-
10
- from loguru import logger
11
- import torch
12
- import torchvision
13
- from torch.utils.data import DataLoader
14
- from tqdm import tqdm
15
-
16
- import wordsegment as ws
17
-
18
- from virtex.config import Config
19
- from virtex.data import ImageDirectoryDataset
20
- from virtex.factories import TokenizerFactory, PretrainingModelFactory
21
- from virtex.utils.checkpointing import CheckpointManager
22
- from virtex.utils.common import common_parser
23
-
24
- CONFIG_PATH = "config.yaml"
25
- MODEL_PATH = "checkpoint_last5.pth"
26
 
27
  # x = st.slider("Select a value")
28
  # st.write(x, "squared is", x * x)
29
 
30
-
31
-
32
- class ImageLoader():
33
- def __init__(self):
34
- self.transformer = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
35
- torchvision.transforms.CenterCrop(224),
36
- torchvision.transforms.ToTensor()])
37
- def load(self, im_path, prompt):
38
- im = torch.FloatTensor(self.transformer(Image.open(im_path))).unsqueeze(0)
39
- return {"image": im, "decode_prompt": prompt}
40
-
41
- class VirTexModel():
42
- def __init__(self):
43
- self.config = Config(CONFIG_PATH)
44
- ws.load()
45
- self.device = 'cpu'
46
- self.tokenizer = TokenizerFactory.from_config(self.config)
47
- self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
48
- CheckpointManager(model=self.model).load("./checkpoint_last5.pth")
49
- self.model.eval()
50
- self.loader = ImageLoader()
51
-
52
- def predict(self, im_path):
53
- subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
54
- predictions: List[Dict[str, Any]] = []
55
- image = self.loader.load(im_path, subreddit_tokens) # should be of shape 1, 3, 224, 224
56
- output_dict = self.model(image)
57
- caption = output_dict["predictions"][0] #only one prediction
58
- caption = caption.tolist()
59
- if self.tokenizer.token_to_id("[SEP]") in caption: # this is just the 0 index actually
60
- sos_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
61
- caption[sos_index] = self.tokenizer.token_to_id("::")
62
-
63
- caption = self.tokenizer.decode(caption)
64
-
65
- # Separate out subreddit from the rest of caption.
66
- if "⁇" in caption: # "⁇" is the token decode equivalent of "::"
67
- subreddit, rest_of_caption = caption.split("⁇")
68
- subreddit = "".join(subreddit.split())
69
- rest_of_caption = rest_of_caption.strip()
70
- else:
71
- subreddit, rest_of_caption = "", caption
72
-
73
- return subreddit, rest_of_caption
74
-
75
- def load_models():
76
- #download model files
77
- download_files = [CONFIG_PATH, MODEL_PATH]
78
- for f in download_files:
79
- fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
80
- os.system(f"cp {fp} ./{f}")
81
-
82
 
83
-
84
- # load a virtex model
85
- from huggingface_hub import hf_hub_url, cached_download
86
-
87
- # #download model files
88
- download_files = [CONFIG_PATH, MODEL_PATH]
89
- for f in download_files:
90
- fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
91
- os.system(f"cp {fp} ./{f}")
92
-
93
- #inference on test.jpg
94
- virtexModel = VirTexModel()
95
- subreddit, caption = virtexModel.predict("./test.jpg")
96
- print(subreddit)
97
- print(caption)
 
 
 
 
 
 
 
 
 
 
 
98
 
 
1
  import streamlit as st
2
+ import io
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  # x = st.slider("Select a value")
5
  # st.write(x, "squared is", x * x)
6
 
7
+ st.title("Image Captioning Demo from Redcaps")
8
+ st.sidebar.markdown(
9
+ """
10
+ Image Captioning Model from VirTex trained on Redcaps
11
+ """
12
+ )
13
+
14
+ with st.spinner("Loading Model"):
15
+ from model import *
16
+ sample_images = glob.glob("./samples/*.jpg")
17
+ download_files()
18
+ virtexModel = VirTexModel()
19
+ imageLoader = ImageLoader()
20
+
21
+ random_image = get_rand_img(sample_images)
22
+
23
+ st.sidebar.title("Select a sample image")
24
+ sample_image = st.sidebar.selectbox(
25
+ "",
26
+ sample_images
27
+ )
28
+
29
+ if st.sidebar.button("Random Sample Image"):
30
+ random_image = get_rand_img(sample_images)
31
+ sample_image = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ uploaded_image = None
34
+ with st.sidebar.form("file-uploader-form", clear_on_submit=True):
35
+ uploaded_file = st.file_uploader("Choose a file")
36
+ submitted = st.form_submit_button("Submit")
37
+ if uploaded_file is not None and submitted:
38
+ uploaded_image = Image.open(io.BytesIO(uploaded_file.get_values()))
39
+
40
+ if uploaded_image is None and submitted:
41
+ st.write("Please select a file to upload")
42
+
43
+ else:
44
+ image_file = sample_image if sample_image is not None else random_image
45
+
46
+ image = uploaded_image if uploaded_image is not None else Image.open()
47
+
48
+ image_dict = imageLoader.transform(image)
49
+
50
+ show.image(st.image(image_dict["image"]), "Target Image")
51
+
52
+ with st.spinner("Generating Caption"):
53
+ subreddit, caption = virtexModel.predict(image_dict)
54
+ st.header("Predicted Caption:\n\n")
55
+ st.subheader(f"Subreddit: {subreddit}\n")
56
+ st.subheader(f"Caption: {caption}\n")
57
+
58
+ image.close()
59
 
virtex/requirements.txt DELETED
@@ -1,18 +0,0 @@
1
- albumentations>=0.5.0
2
- Cython>=0.25
3
- ftfy==5.8
4
- future==0.18.0
5
- lmdb==0.97
6
- loguru==0.3.2
7
- mypy_extensions==0.4.1
8
- lvis==0.5.3
9
- numpy>=1.17
10
- opencv-python==4.1.2.30
11
- scikit-learn==0.21.3
12
- sentencepiece>=0.1.90
13
- torch==1.7.0
14
- torchvision==0.8
15
- tqdm>=4.50.0
16
- wordsegment==1.3.1
17
- git+git://github.com/facebookresearch/fvcore.git#egg=fvcore
18
- git+git://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI