kdexd committed
Commit 8d0e872 • 1 Parent(s): 5650fb4

Black + isort, remove unused virtx files.
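
The formatting changes below are mechanical. A hedged sketch of reproducing them through the tools' Python APIs (the commit itself was presumably produced with the `isort .` and `black .` CLIs):

```python
# Hedged sketch of the reformatting in this commit, via black's and isort's
# Python APIs; the actual run was presumably the equivalent CLI commands.
import black
import isort

with open("app.py") as f:
    source = f.read()

source = isort.code(source)                           # regroup and sort imports
source = black.format_str(source, mode=black.Mode())  # apply Black's style
with open("app.py", "w") as f:
    f.write(source)
```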

Files changed (50)
  1. app.py +32 -25
  2. model.py +72 -58
  3. virtex/CHANGELOG.md +0 -41
  4. virtex/LICENSE +0 -16
  5. virtex/README.md +0 -92
  6. virtex/{virtex/__init__.py → __init__.py} +0 -0
  7. virtex/{virtex/config.py → config.py} +0 -0
  8. virtex/configs/_base_bicaptioning_R_50_L1_H1024.yaml +0 -66
  9. virtex/configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml +0 -5
  10. virtex/configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml +0 -5
  11. virtex/configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml +0 -1
  12. virtex/configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml +0 -1
  13. virtex/configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml +0 -5
  14. virtex/configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml +0 -5
  15. virtex/configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml +0 -5
  16. virtex/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml +0 -49
  17. virtex/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml +0 -75
  18. virtex/configs/detectron2/coco_segm_default_init_2x.yaml +0 -24
  19. virtex/configs/detectron2/lvis_segm_default_init_2x.yaml +0 -36
  20. virtex/configs/detectron2/lvis_segm_imagenet_init_2x.yaml +0 -38
  21. virtex/configs/detectron2/voc_det_default_init_24k.yaml +0 -28
  22. virtex/configs/downstream/imagenet_clf.yaml +0 -33
  23. virtex/configs/downstream/inaturalist_clf.yaml +0 -36
  24. virtex/configs/downstream/voc07_clf.yaml +0 -15
  25. virtex/configs/redcaps/gcc_R_50_L6_H512.yaml +0 -35
  26. virtex/configs/redcaps/miniclip_sbu_R_50_L12_H512.yaml +0 -35
  27. virtex/configs/redcaps/redcaps_2020_R_50_L6_H512.yaml +0 -35
  28. virtex/configs/redcaps/redcaps_all_R_50_L6_H512.yaml +0 -35
  29. virtex/configs/redcaps/sbu_R_50_L6_H512.yaml +0 -35
  30. virtex/configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml +0 -5
  31. virtex/configs/task_ablations/captioning_R_50_L1_H2048.yaml +0 -6
  32. virtex/configs/task_ablations/masked_lm_R_50_L1_H2048.yaml +0 -6
  33. virtex/configs/task_ablations/multilabel_classification_R_50.yaml +0 -12
  34. virtex/configs/task_ablations/token_classification_R_50.yaml +0 -9
  35. virtex/configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml +0 -1
  36. virtex/configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml +0 -5
  37. virtex/configs/width_ablations/bicaptioning_R_50_L1_H512.yaml +0 -5
  38. virtex/configs/width_ablations/bicaptioning_R_50_L1_H768.yaml +0 -5
  39. virtex/{virtex/data → data}/__init__.py +0 -0
  40. virtex/{virtex/data → data}/datasets/captioning.py +0 -0
  41. virtex/{virtex/data → data}/datasets/classification.py +0 -0
  42. virtex/{virtex/data → data}/datasets/downstream.py +0 -0
  43. virtex/{virtex/data → data}/datasets/masked_lm.py +0 -0
  44. virtex/{virtex/data → data}/datasets/redcaps.py +0 -0
  45. virtex/{virtex/data → data}/datasets/zero_shot.py +0 -0
  46. virtex/{virtex/data → data}/readers.py +0 -0
  47. virtex/{virtex/data → data}/tokenizers.py +0 -0
  48. virtex/{virtex/data → data}/transforms.py +0 -0
  49. virtex/docs/Makefile +0 -19
  50. virtex/docs/_static/custom.css +0 -115
app.py CHANGED
@@ -1,18 +1,18 @@
-import streamlit as st
 import io
-import sys
-import time
-import json
-sys.path.append("./virtex/")
+
+import streamlit as st
+
 from model import *
 
 # # TODO:
 # - Reformat the model introduction
 # - Make the iterative text generation
 
-def gen_show_caption(sub_prompt=None, cap_prompt = ""):
+
+def gen_show_caption(sub_prompt=None, cap_prompt=""):
     with st.spinner("Generating Caption"):
-        subreddit, caption = virtexModel.predict(image_dict, sub_prompt=sub_prompt, prompt=cap_prompt)
+        subreddit, caption = virtexModel.predict(
+            image_dict, sub_prompt=sub_prompt, prompt=cap_prompt
+        )
     st.markdown(
         f"""
         <style>
@@ -28,10 +28,12 @@ def gen_show_caption(sub_prompt=None, cap_prompt = ""):
         </style>
 
         ### <red> r/{subreddit} </red> <blue> {cap_prompt} </blue> {caption}
-    """,
-    unsafe_allow_html=True)
-
-_, center, _ = st.columns([1,8,1])
+        """,
+        unsafe_allow_html=True,
+    )
+
+
+_, center, _ = st.columns([1, 8, 1])
 
 with center:
     st.title("Image Captioning Demo from RedCaps")
@@ -50,7 +52,7 @@ st.sidebar.markdown(
 
 with st.spinner("Loading Model"):
     virtexModel, imageLoader, sample_images, valid_subs = create_objects()
-
+
 
 select_idx = None
 
@@ -66,9 +68,9 @@ uploaded_image = None
 # with st.sidebar.form("file-uploader-form", clear_on_submit=True):
 uploaded_file = st.sidebar.file_uploader("Choose a file")
 # submitted = st.form_submit_button("Submit")
-if uploaded_file is not None:# and submitted:
+if uploaded_file is not None:  # and submitted:
     uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
-    select_idx = None # set this to help rewrite the cache
+    select_idx = None  # set this to help rewrite the cache
 
 # class OnChange():
 #     def __init__(self, idx):
@@ -88,21 +90,26 @@ if uploaded_file is not None:# and submitted:
 st.sidebar.title("Select a Subreddit")
 sub = st.sidebar.selectbox(
     "Type below to condition on a subreddit. Select None for a predicted subreddit",
-    valid_subs
+    valid_subs,
 )
 
 st.sidebar.title("Write a Custom Prompt")
-cap_prompt = st.sidebar.text_input(
-    "Write the start of your caption below",
-    value=""
-)
+cap_prompt = st.sidebar.text_input("Write the start of your caption below", value="")
 
 _ = st.sidebar.button("Regenerate Caption")
 
 
 st.sidebar.write("Advanced Options:")
-num_captions = st.sidebar.select_slider("Number of Captions to Predict", options=[1,2,3,4,5], value=1)
-nuc_size = st.sidebar.slider("Nucelus Size:\nLarger values lead to more diverse captions", min_value=0.0, max_value=1.0, value=0.8, step=0.05)
+num_captions = st.sidebar.select_slider(
+    "Number of Captions to Predict", options=[1, 2, 3, 4, 5], value=1
+)
+nuc_size = st.sidebar.slider(
+    "Nucelus Size:\nLarger values lead to more diverse captions",
+    min_value=0.0,
+    max_value=1.0,
+    value=0.8,
+    step=0.05,
+)
 virtexModel.model.decoder.nucleus_size = nuc_size
 
 image_file = sample_image
@@ -110,14 +117,14 @@ image_file = sample_image
 # LOAD AND CACHE THE IMAGE
 if uploaded_image is not None:
     image = uploaded_image
-elif select_idx is None and 'image' in st.session_state:
-    image = st.session_state['image']
+elif select_idx is None and "image" in st.session_state:
+    image = st.session_state["image"]
 else:
     image = Image.open(image_file)
 
 image = image.convert("RGB")
 
-st.session_state['image'] = image
+st.session_state["image"] = image
 
 
 image_dict = imageLoader.transform(image)
@@ -141,4 +148,4 @@ This demo accompanies our paper RedCaps.
 
 Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson
 """
-    )
+)
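
The advanced options above wire the sidebar directly into decoding: `virtexModel.model.decoder.nucleus_size = nuc_size` sets the p of top-p (nucleus) sampling, where the decoder samples only from the smallest set of tokens whose cumulative probability reaches p, so larger values admit more candidates and yield more diverse captions. A minimal illustrative sketch of that filtering step (not the repo's decoder implementation):

```python
# Illustrative top-p (nucleus) filtering over 1-D vocabulary logits. This only
# sketches what the nucleus_size slider controls; virtex's decoder has its own
# implementation.
import torch
import torch.nn.functional as F

def nucleus_filter(logits: torch.Tensor, p: float = 0.8) -> torch.Tensor:
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    # Keep tokens whose preceding cumulative mass is below p (always >= 1 token).
    keep = torch.cumsum(sorted_probs, dim=-1) - sorted_probs < p
    filtered = torch.full_like(logits, float("-inf"))
    filtered[sorted_idx[keep]] = logits[sorted_idx[keep]]
    return filtered

logits = torch.randn(10)  # placeholder logits over a tiny vocabulary
token = torch.multinomial(F.softmax(nucleus_filter(logits, p=0.8), dim=-1), 1)
```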
model.py CHANGED
@@ -1,18 +1,17 @@
-import streamlit as st
-from huggingface_hub import hf_hub_url, cached_download
-from PIL import Image
 import os
 import json
 import glob
 import random
-from typing import Any, Dict, List
 import torch
 import torchvision
 
+import streamlit as st
 import wordsegment as ws
+from PIL import Image
+from huggingface_hub import hf_hub_url, cached_download
 
 from virtex.config import Config
-from virtex.factories import TokenizerFactory, PretrainingModelFactory, ImageTransformsFactory
+from virtex.factories import TokenizerFactory, PretrainingModelFactory
 from virtex.utils.checkpointing import CheckpointManager
 
 CONFIG_PATH = "config.yaml"
@@ -20,98 +19,108 @@ MODEL_PATH = "checkpoint_last5.pth"
 VALID_SUBREDDITS_PATH = "subreddit_list.json"
 SAMPLES_PATH = "./samples/*.jpg"
 
-class ImageLoader():
+
+class ImageLoader:
     def __init__(self):
-        self.image_transform = torchvision.transforms.Compose([
-            torchvision.transforms.ToTensor(),
-            torchvision.transforms.Resize(256),
-            torchvision.transforms.CenterCrop(224),
-            torchvision.transforms.Normalize((.485, .456, .406), (.229, .224, .225))])
-        self.show_size=500
-
+        self.image_transform = torchvision.transforms.Compose(
+            [
+                torchvision.transforms.ToTensor(),
+                torchvision.transforms.Resize(256),
+                torchvision.transforms.CenterCrop(224),
+                torchvision.transforms.Normalize(
+                    (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
+                ),
+            ]
+        )
+        self.show_size = 500
+
     def load(self, im_path):
         im = torch.FloatTensor(self.image_transform(Image.open(im_path))).unsqueeze(0)
         return {"image": im}
-
+
     def raw_load(self, im_path):
         im = torch.FloatTensor(Image.open(im_path))
         return {"image": im}
-
+
     def transform(self, image):
         im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0)
         return {"image": im}
-
+
     def text_transform(self, text):
         # at present just lowercasing:
         return text.lower()
-
+
     def show_resize(self, image):
         # ugh we need to do this manually cuz this is pytorch==0.8 not 1.9 lol
        image = torchvision.transforms.functional.to_tensor(image)
-        x,y = image.shape[-2:]
-        ratio = float(self.show_size/max((x,y)))
-        image = torchvision.transforms.functional.resize(image, [int(x * ratio), int(y * ratio)])
+        x, y = image.shape[-2:]
+        ratio = float(self.show_size / max((x, y)))
+        image = torchvision.transforms.functional.resize(
+            image, [int(x * ratio), int(y * ratio)]
+        )
         return torchvision.transforms.functional.to_pil_image(image)
-
 
-class VirTexModel():
+
+class VirTexModel:
+
     def __init__(self):
         self.config = Config(CONFIG_PATH)
         ws.load()
-        self.device = 'cpu'
+        self.device = "cpu"
         self.tokenizer = TokenizerFactory.from_config(self.config)
         self.model = PretrainingModelFactory.from_config(self.config).to(self.device)
         CheckpointManager(model=self.model).load(MODEL_PATH)
         self.model.eval()
         self.valid_subs = json.load(open(VALID_SUBREDDITS_PATH))
-
-    def predict(self, image_dict, sub_prompt = None, prompt = ""):
+
+    def predict(self, image_dict, sub_prompt=None, prompt=""):
         if sub_prompt is None:
-            subreddit_tokens = torch.tensor([self.model.sos_index], device=self.device).long()
+            subreddit_tokens = torch.tensor(
+                [self.model.sos_index], device=self.device
+            ).long()
         else:
             subreddit_tokens = " ".join(ws.segment(ws.clean(sub_prompt)))
             subreddit_tokens = (
-                [self.model.sos_index] +
-                self.tokenizer.encode(subreddit_tokens) +
-                [self.tokenizer.token_to_id("[SEP]")]
-            )
+                [self.model.sos_index]
+                + self.tokenizer.encode(subreddit_tokens)
+                + [self.tokenizer.token_to_id("[SEP]")]
+            )
             subreddit_tokens = torch.tensor(subreddit_tokens, device=self.device).long()
-
+
         if prompt is not "":
             # at present prompts without subreddits will break without this change
             # TODO FIX
             cap_tokens = self.tokenizer.encode(prompt)
             cap_tokens = torch.tensor(cap_tokens, device=self.device).long()
-            subreddit_tokens = subreddit_tokens if sub_prompt is not None else torch.tensor(
-                (
-                    [self.model.sos_index] +
-                    self.tokenizer.encode("pics") +
-                    [self.tokenizer.token_to_id("[SEP]")]
-                ), device = self.device).long()
-
-            subreddit_tokens = torch.cat(
-                [
-                    subreddit_tokens,
-                    cap_tokens
-                ])
-
-
-        predictions: List[Dict[str, Any]] = []
-
+            subreddit_tokens = (
+                subreddit_tokens
+                if sub_prompt is not None
+                else torch.tensor(
+                    (
+                        [self.model.sos_index]
+                        + self.tokenizer.encode("pics")
+                        + [self.tokenizer.token_to_id("[SEP]")]
+                    ),
+                    device=self.device,
+                ).long()
+            )
+
+            subreddit_tokens = torch.cat([subreddit_tokens, cap_tokens])
+
         is_valid_subreddit = False
         subreddit, rest_of_caption = "", ""
         image_dict["decode_prompt"] = subreddit_tokens
         while not is_valid_subreddit:
-
+
             with torch.no_grad():
                 caption = self.model(image_dict)["predictions"][0].tolist()
-
+
             if self.tokenizer.token_to_id("[SEP]") in caption:
                 sep_index = caption.index(self.tokenizer.token_to_id("[SEP]"))
                 caption[sep_index] = self.tokenizer.token_to_id("://")
-
+
             caption = self.tokenizer.decode(caption)
-
+
             if "://" in caption:
                 subreddit, rest_of_caption = caption.split("://")
                 subreddit = "".join(subreddit.split())
@@ -122,25 +131,29 @@ class VirTexModel():
             # split prompt for coloring:
             if prompt is not "":
                 _, rest_of_caption = caption.split(prompt.strip())
-
+
             is_valid_subreddit = subreddit in self.valid_subs
-
+
         return subreddit, rest_of_caption
 
+
 def download_files():
-    #download model files
+    # download model files
     download_files = [CONFIG_PATH, MODEL_PATH, VALID_SUBREDDITS_PATH]
     for f in download_files:
         fp = cached_download(hf_hub_url("zamborg/redcaps", filename=f))
         os.system(f"cp {fp} ./{f}")
 
+
 def get_samples():
     return glob.glob(SAMPLES_PATH)
 
+
 def get_rand_idx(samples):
-    return random.randint(0,len(samples)-1)
+    return random.randint(0, len(samples) - 1)
+
 
-@st.cache(allow_output_mutation=True) # allow mutation to update nucleus size
+@st.cache(allow_output_mutation=True)  # allow mutation to update nucleus size
 def create_objects():
     sample_images = get_samples()
     virtexModel = VirTexModel()
@@ -149,7 +162,8 @@ def create_objects():
     valid_subs.insert(0, None)
     return virtexModel, imageLoader, sample_images, valid_subs
 
-footer="""<style>
+
+footer = """<style>
 a:link , a:visited{
     color: blue;
     background-color: transparent;
@@ -181,4 +195,4 @@ This demo accompanies our paper RedCaps.
 Created by Karan Desai, Gaurav Kaul, Zubin Aysola, Justin Johnson
 </p>
 </div>
-"""
+"""
virtex/CHANGELOG.md DELETED
@@ -1,41 +0,0 @@
-ArXiv v1 -> v2 CHANGELOG
-=========================
-
-[ArXiv v1](https://arxiv.org/abs/2006.06666v1) was our ECCV 2020 submission (reject). [ArXiv v2](https://arxiv.org/abs/2006.06666v2) is our CVPR 2021 submission (accept). The repository snapshots for these two versions are tagged at [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9) and [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0).
-
-While the core motivation and approach are the same, we have made some minor changes in our experiments and evaluation setup. These slightly improve model performances across the board (within decimals). New models are available in the [`v1.0` model zoo](http://kdexd.github.io/virtex/virtex/usage/model_zoo.html); links to old models in `v0.9` will remain active till June 30, 2021. We encourage you to use the new models!
-
-We have updated the experiment config files for all changes described below.
-
-Experiment Changes
-------------------
-
-### New Feature:
-
-Added a new pretraining task for BERT-style _Masked Language Modeling_. Pre-trained model released in Model Zoo.
-
-### Pre-training:
-
-- The only change during pre-training is that we do not apply weight decay to LayerNorm and biases in the input embedding and transformer layers. We do apply weight decay to the biases in the output linear layer (before softmax).
-
-- Other factors that could affect results:
-  - Use the official [albumentations.ColorJitter transform](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter) that mimics the torchvision ColorJitter transform. Earlier I implemented [my own ColorJitter](https://github.com/kdexd/virtex/blob/c19e7fc9b98e98af82286ed1537b6f588eaeac44/virtex/data/transforms.py#L156) because albumentations didn't have one.
-  - Use PyTorch native AMP (Automatic Mixed Precision) instead of NVIDIA Apex.
-
-### Downstream Evaluations:
-
-1. **PASCAL VOC 2007 Linear Classification:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-b4405dd4879a48ef1e5b1e2801035909584a5f1f32f63d5e793fb50dee077b97)
-   - Instead of training linear SVMs on 8192-dimensional average pooled features from ResNet-50 (7x7x2048 -> 2x2x2048), like [(Misra et al. 2019)](https://arxiv.org/abs/1905.01235), we directly train SVMs on 2048-dimensional global average pooled features, following recent works like [SwAV (Caron et al. 2020)](https://arxiv.org/abs/2006.09882).
-   - We change the pre-processing: resize the shortest edge to 256 pixels, and take a center crop of 224 pixels.
-   - These improve VOC mAP by 1-2 points everywhere and make SVM training faster. Since we select the best checkpoint based on this metric, all results on other downstream tasks also change in `ArXiv v2` (but the trends remain the same).
-
-2. **ImageNet Linear Evaluation:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-d3dea1e7bf97d0cfca4b59a47c0a9bb81e78b8827654fe0258df9ce2c3f5f41c)
-   - Changed the random resized crop scale from (20-100%) to (8-100%) for consistency with evaluations in SSL works like MoCo and SwAV.
-   - Use cosine LR decay instead of step decay, following SwAV. Improves accuracy by up to 1%.
-
-3. **iNaturalist Fine-tuning:** [[diff]](https://github.com/kdexd/virtex/compare/57889ca9829f27b932e92b9e6b51f50f20f2d546..7645cc0d1e3e49f00e347e9873fd020faa2ec62e#diff-09096da78cfcde3a604ce22d80313f0800225d928cce5ef7334b89a382adfe4d)
-   - This evaluation is left unchanged across ArXiv versions, but we fixed a typo in the image pre-processing step, present in the publicly released config.
-
-4. **Detectron2 tasks (COCO and LVIS Instance Segmentation, VOC Detection):**
-   - Heavily simplified the script. Updated Detectron2 uses a more memory-efficient SyncBatchNorm and supports AMP.
virtex/LICENSE DELETED
@@ -1,16 +0,0 @@
-Copyright (c) 2020, Karan Desai.
-
-Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
-associated documentation files (the "Software"), to deal in the Software without restriction,
-including without limitation the rights to use, copy, modify, merge, publish, distribute,
-sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all copies or substantial
-portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
-NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
-OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
-CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
virtex/README.md DELETED
@@ -1,92 +0,0 @@
-VirTex: Learning Visual Representations from Textual Annotations
-================================================================
-
-<h4>
-Karan Desai and Justin Johnson
-</br>
-<span style="font-size: 14pt; color: #555555">
-University of Michigan
-</span>
-</h4>
-<hr>
-
-**CVPR 2021** [arxiv.org/abs/2006.06666][1]
-
-**Model Zoo, Usage Instructions and API docs:** [kdexd.github.io/virtex](https://kdexd.github.io/virtex)
-
-VirTex is a pretraining approach which uses semantically dense captions to
-learn visual representations. We train CNN + Transformers from scratch on
-COCO Captions, and transfer the CNN to downstream vision tasks including
-image classification, object detection, and instance segmentation.
-VirTex matches or outperforms models which use ImageNet for pretraining --
-both supervised and unsupervised -- despite using up to 10x fewer images.
-
-![virtex-model](docs/_static/system_figure.jpg)
-
-
-Get the pretrained ResNet-50 visual backbone from our best performing VirTex
-model in one line *without any installation*!
-
-```python
-import torch
-
-# That's it, this one line only requires PyTorch.
-model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True)
-```
-
-### Note (For returning users before January 2021):
-
-The pretrained models in our model zoo have changed from [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0) onwards.
-They are slightly better tuned than older models, and reproduce the results in our
-CVPR 2021 accepted paper ([arXiv v2](https://arxiv.org/abs/2006.06666v2)).
-Some training and evaluation hyperparams have changed since [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9).
-Please refer to [`CHANGELOG.md`](https://github.com/kdexd/virtex/blob/master/CHANGELOG.md).
-
-
-Usage Instructions
-------------------
-
-1. [How to setup this codebase?][2]
-2. [VirTex Model Zoo][3]
-3. [How to train your VirTex model?][4]
-4. [How to evaluate on downstream tasks?][5]
-
-Full documentation is available at [kdexd.github.io/virtex](https://kdexd.github.io/virtex).
-
-
-Citation
---------
-
-If you find this code useful, please consider citing:
-
-```text
-@inproceedings{desai2021virtex,
-    title={{VirTex: Learning Visual Representations from Textual Annotations}},
-    author={Karan Desai and Justin Johnson},
-    booktitle={CVPR},
-    year={2021}
-}
-```
-
-Acknowledgments
----------------
-
-We thank Harsh Agrawal, Mohamed El Banani, Richard Higgins, Nilesh Kulkarni
-and Chris Rockwell for helpful discussions and feedback on the paper. We thank
-Ishan Misra for discussions regarding PIRL evaluation protocol; Saining Xie for
-discussions about replicating iNaturalist evaluation as MoCo; Ross Girshick and
-Yuxin Wu for help with Detectron2 model zoo; Georgia Gkioxari for suggesting
-the Instance Segmentation pretraining task ablation; and Stefan Lee for
-suggestions on figure aesthetics. We thank Jia Deng for access to extra GPUs
-during project development; and UMich ARC-TS team for support with GPU cluster
-management. Finally, we thank all the Starbucks outlets in Ann Arbor for many
-hours of free WiFi. This work was partially supported by the Toyota Research
-Institute (TRI). However, note that this article solely reflects the opinions
-and conclusions of its authors and not TRI or any other Toyota entity.
-
-
-[1]: https://arxiv.org/abs/2006.06666
-[2]: https://kdexd.github.io/virtex/virtex/usage/setup_dependencies.html
-[3]: https://kdexd.github.io/virtex/virtex/usage/model_zoo.html
-[4]: https://kdexd.github.io/virtex/virtex/usage/pretrain.html
-[5]: https://kdexd.github.io/virtex/virtex/usage/downstream.html
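
Following the README's one-liner, a hedged sketch of pushing an image batch through the hub-loaded backbone (the input preprocessing and the exact output format of the hub entry point are assumptions, not confirmed by this repo):

```python
# Hedged sketch: run a batch through the torch.hub backbone from the README
# above. The output type/shape of the hub entry is an assumption.
import torch

model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True)
model.eval()

batch = torch.rand(1, 3, 224, 224)  # placeholder; normalize real images first
with torch.no_grad():
    features = model(batch)
print(type(features))
```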
virtex/{virtex/__init__.py → __init__.py} RENAMED
File without changes
virtex/{virtex/config.py → config.py} RENAMED
File without changes
virtex/configs/_base_bicaptioning_R_50_L1_H1024.yaml DELETED
@@ -1,66 +0,0 @@
-# -----------------------------------------------------------------------------
-# Base config: VirTex pretraining for our "base" bicaptioning model:
-# ResNet-50 + (L = 1, H = 1024) transformer trained for 500K iterations.
-# -----------------------------------------------------------------------------
-RANDOM_SEED: 0
-AMP: true
-CUDNN_BENCHMARK: true
-CUDNN_DETERMINISTIC: false
-
-DATA:
-  ROOT: "datasets/coco"
-  TOKENIZER_MODEL: "datasets/vocab/coco_10k.model"
-  VOCAB_SIZE: 10000
-  UNK_INDEX: 0
-  SOS_INDEX: 1
-  EOS_INDEX: 2
-  MASK_INDEX: 3
-
-  IMAGE_CROP_SIZE: 224
-  MAX_CAPTION_LENGTH: 30
-
-  IMAGE_TRANSFORM_TRAIN:
-    - "random_resized_crop"
-    - "horizontal_flip"
-    - "color_jitter"
-    - "normalize"
-
-  IMAGE_TRANSFORM_VAL:
-    - "smallest_resize"
-    - "center_crop"
-    - "normalize"
-
-  USE_PERCENTAGE: 100.0
-  USE_SINGLE_CAPTION: false
-
-MODEL:
-  NAME: "virtex"
-  VISUAL:
-    NAME: "torchvision::resnet50"
-    PRETRAINED: false
-    FROZEN: false
-  TEXTUAL:
-    NAME: "transdec_postnorm::L1_H1024_A16_F4096"
-    DROPOUT: 0.1
-
-OPTIM:
-  OPTIMIZER_NAME: "sgd"
-  SGD_MOMENTUM: 0.9
-  WEIGHT_DECAY: 0.0001
-
-  LOOKAHEAD:
-    USE: true
-    ALPHA: 0.5
-    STEPS: 5
-
-  BATCH_SIZE: 256
-  CNN_LR: 0.2
-  LR: 0.001
-  NUM_ITERATIONS: 500000
-
-  WARMUP_STEPS: 10000
-  LR_DECAY_NAME: "cosine"
-
-  NO_DECAY: ".*textual.(embedding|transformer).*(norm.*|bias)"
-  CLIP_GRAD_NORM: 10.0
-
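
Each of the deleted ablation configs below overrides only a few keys of this base file via `_BASE_`, which `virtex.config.Config` resolves at load time. A hedged sketch of how these YAML files are consumed (the path and the override-list format are assumptions based on the upstream VirTex repo):

```python
# Hedged sketch: Config resolves the _BASE_ chain, then applies dotted-key
# overrides; override_list format and attribute access are assumptions.
from virtex.config import Config

config = Config(
    "configs/width_ablations/bicaptioning_R_50_L1_H512.yaml",
    override_list=["OPTIM.BATCH_SIZE", 128],
)
print(config.MODEL.TEXTUAL.NAME)  # expected: "transdec_postnorm::L1_H512_A8_F2048"
```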
virtex/configs/backbone_ablations/bicaptioning_R_101_L1_H1024.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  VISUAL:
-    NAME: "torchvision::resnet101"
virtex/configs/backbone_ablations/bicaptioning_R_50W2X_L1_H1024.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  VISUAL:
-    NAME: "torchvision::wide_resnet50_2"
virtex/configs/backbone_ablations/bicaptioning_R_50_L1_H1024.yaml DELETED
@@ -1 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
virtex/configs/depth_ablations/bicaptioning_R_50_L1_H1024.yaml DELETED
@@ -1 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
virtex/configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  TEXTUAL:
-    NAME: "transdec_postnorm::L2_H1024_A16_F4096"
virtex/configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  TEXTUAL:
-    NAME: "transdec_postnorm::L3_H1024_A16_F4096"
virtex/configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  TEXTUAL:
-    NAME: "transdec_postnorm::L4_H1024_A16_F4096"
virtex/configs/detectron2/_base_faster_rcnn_R_50_C4_BN.yaml DELETED
@@ -1,49 +0,0 @@
-# ----------------------------------------------------------------------------
-# Train a Faster R-CNN with ResNet-50 and C4 backbone. This config follows
-# Detectron2 format; and is unrelated with our VirTex configs. Params here
-# replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722).
-# ----------------------------------------------------------------------------
-
-INPUT:
-  # Input format will always be RGB, consistent with torchvision.
-  FORMAT: "RGB"
-  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
-  MIN_SIZE_TEST: 800
-
-MODEL:
-  META_ARCHITECTURE: "GeneralizedRCNN"
-
-  # Train all layers end-to-end by default.
-  BACKBONE:
-    NAME: build_resnet_backbone
-    FREEZE_AT: 0
-
-  # Fine-tune with SyncBN.
-  # STRIDE_IN_1X1 is False for torchvision-like models.
-  RESNETS:
-    DEPTH: 50
-    NORM: SyncBN
-    STRIDE_IN_1X1: False
-
-  RPN:
-    PRE_NMS_TOPK_TEST: 6000
-    POST_NMS_TOPK_TEST: 1000
-
-  # ROI head with extra BN layer after res5 stage.
-  ROI_HEADS:
-    NAME: "Res5ROIHeadsExtraNorm"
-
-  # ImageNet color mean for torchvision-like models (RGB order).
-  PIXEL_MEAN: [123.675, 116.280, 103.530]
-  PIXEL_STD: [58.395, 57.120, 57.375]
-
-SOLVER:
-  # This is for 8 GPUs, apply linear scaling for 4 GPUs.
-  IMS_PER_BATCH: 16
-  BASE_LR: 0.02
-
-TEST:
-  PRECISE_BN:
-    ENABLED: True
-
-VERSION: 2
virtex/configs/detectron2/_base_mask_rcnn_R_50_FPN.yaml DELETED
@@ -1,75 +0,0 @@
-# ----------------------------------------------------------------------------
-# Train a Mask R-CNN with ResNet-50 and FPN backbone. This config follows
-# Detectron2 format; and is unrelated with our VirTex configs. Params here
-# replicate evaluation protocol as per MoCo (https://arxiv.org/abs/1911.05722).
-# ----------------------------------------------------------------------------
-
-INPUT:
-  # Input format will always be RGB, consistent with torchvision.
-  FORMAT: "RGB"
-  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
-  MIN_SIZE_TEST: 800
-
-MODEL:
-  META_ARCHITECTURE: "GeneralizedRCNN"
-
-  # Train all layers end-to-end by default.
-  BACKBONE:
-    NAME: "build_resnet_fpn_backbone"
-    FREEZE_AT: 0
-
-  # Fine-tune with SyncBN.
-  # STRIDE_IN_1X1 is False for torchvision-like models.
-  RESNETS:
-    DEPTH: 50
-    NORM: "SyncBN"
-    STRIDE_IN_1X1: False
-    OUT_FEATURES: ["res2", "res3", "res4", "res5"]
-
-  FPN:
-    IN_FEATURES: ["res2", "res3", "res4", "res5"]
-
-  ANCHOR_GENERATOR:
-    # One size for each in feature map
-    SIZES: [[32], [64], [128], [256], [512]]
-    # Three aspect ratios (same for all in feature maps)
-    ASPECT_RATIOS: [[0.5, 1.0, 2.0]]
-
-  RPN:
-    IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
-    PRE_NMS_TOPK_TRAIN: 2000
-    PRE_NMS_TOPK_TEST: 1000
-
-    POST_NMS_TOPK_TRAIN: 1000
-    POST_NMS_TOPK_TEST: 1000
-
-  ROI_HEADS:
-    NAME: "StandardROIHeads"
-    IN_FEATURES: ["p2", "p3", "p4", "p5"]
-
-  ROI_BOX_HEAD:
-    NAME: "FastRCNNConvFCHead"
-    NUM_FC: 2
-    POOLER_RESOLUTION: 7
-
-  ROI_MASK_HEAD:
-    NAME: "MaskRCNNConvUpsampleHead"
-    NUM_CONV: 4
-    POOLER_RESOLUTION: 14
-
-  # ImageNet color mean for torchvision-like models (RGB order).
-  # These are in [0-255] range as expected by Detectron2. Rest of our codebase
-  # uses [0-1] range; but both are equivalent and consistent.
-  PIXEL_MEAN: [123.675, 116.280, 103.530]
-  PIXEL_STD: [58.395, 57.120, 57.375]
-
-SOLVER:
-  # This is for 8 GPUs, apply linear scaling for 4 GPUs.
-  IMS_PER_BATCH: 16
-  BASE_LR: 0.02
-
-TEST:
-  PRECISE_BN:
-    ENABLED: True
-
-VERSION: 2
virtex/configs/detectron2/coco_segm_default_init_2x.yaml DELETED
@@ -1,24 +0,0 @@
-# -----------------------------------------------------------------------------
-# Train a Mask R-CNN R50-FPN backbone on COCO instance segmentation with any of
-# these weight init: random, imagenet (torchvision), virtex or MoCo.
-# -----------------------------------------------------------------------------
-_BASE_: "_base_mask_rcnn_R_50_FPN.yaml"
-
-DATASETS:
-  TRAIN: ("coco_2017_train",)
-  TEST: ("coco_2017_val",)
-
-MODEL:
-  MASK_ON: True
-  # FPN also has SyncBN, as opposed to no norm (usually).
-  FPN:
-    NORM: "SyncBN"
-
-  # This will be ignored, weights will be loaded manually in the script.
-  WEIGHTS: ""
-
-SOLVER:
-  STEPS: (120000, 160000)
-  MAX_ITER: 180000
-
-VERSION: 2
virtex/configs/detectron2/lvis_segm_default_init_2x.yaml DELETED
@@ -1,36 +0,0 @@
-# -----------------------------------------------------------------------------
-# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation with any of
-# these weight init: random, virtex or MoCo. (ImageNet init config is separate)
-# -----------------------------------------------------------------------------
-_BASE_: "_base_mask_rcnn_R_50_FPN.yaml"
-
-DATASETS:
-  TRAIN: ("lvis_v1_train",)
-  TEST: ("lvis_v1_val",)
-
-DATALOADER:
-  SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
-  REPEAT_THRESHOLD: 0.001
-
-TEST:
-  DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300.
-
-MODEL:
-  MASK_ON: True
-  # FPN also has SyncBN, as opposed to no norm (usually).
-  FPN:
-    NORM: "SyncBN"
-
-  ROI_HEADS:
-    NUM_CLASSES: 1203
-    SCORE_THRESH_TEST: 0.0001
-
-  # This will be ignored, weights will be loaded manually in the script.
-  WEIGHTS: ""
-
-SOLVER:
-  STEPS: (120000, 160000)
-  MAX_ITER: 180000
-
-VERSION: 2
-
virtex/configs/detectron2/lvis_segm_imagenet_init_2x.yaml DELETED
@@ -1,38 +0,0 @@
-# -----------------------------------------------------------------------------
-# Train a Mask R-CNN R50-FPN backbone on LVIS instance segmentation
-# with weights initialized from supervised ImageNet pretraining (torchvision).
-# Key difference is that fine-tuning here happens with BN frozen.
-# -----------------------------------------------------------------------------
-_BASE_: "_base_mask_rcnn_R_50_FPN.yaml"
-
-DATASETS:
-  TRAIN: ("lvis_v1_train",)
-  TEST: ("lvis_v1_val",)
-
-DATALOADER:
-  SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
-  REPEAT_THRESHOLD: 0.001
-
-TEST:
-  DETECTIONS_PER_IMAGE: 300  # LVIS allows up to 300.
-
-MODEL:
-  MASK_ON: True
-  RESNETS:
-    NORM: "FrozenBN"
-
-  # Do not tune with SyncBN for ImageNet init from LVIS.
-  ROI_HEADS:
-    NUM_CLASSES: 1203
-    SCORE_THRESH_TEST: 0.0001
-
-  # This will be ignored, weights will be loaded manually in the script.
-  WEIGHTS: ""
-
-SOLVER:
-  STEPS: (120000, 160000)
-  MAX_ITER: 180000
-
-VERSION: 2
-
-
virtex/configs/detectron2/voc_det_default_init_24k.yaml DELETED
@@ -1,28 +0,0 @@
-# -----------------------------------------------------------------------------
-# Train a Faster R-CNN with R50-C4 backbone on VOC07+12 detection with any of
-# these weight init: random, imagenet (torchvision), virtex or MoCo.
-# -----------------------------------------------------------------------------
-_BASE_: "_base_faster_rcnn_R_50_C4_BN.yaml"
-
-DATASETS:
-  TRAIN: ("voc_2007_trainval", "voc_2012_trainval")
-  TEST: ("voc_2007_test",)
-
-INPUT:
-  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
-  MIN_SIZE_TEST: 800
-
-MODEL:
-  MASK_ON: False
-  ROI_HEADS:
-    NUM_CLASSES: 20
-
-  # This will be ignored, weights will be loaded manually in the script.
-  WEIGHTS: ""
-
-SOLVER:
-  STEPS: (18000, 22000)
-  MAX_ITER: 24000
-  WARMUP_ITERS: 100
-
-VERSION: 2
virtex/configs/downstream/imagenet_clf.yaml DELETED
@@ -1,33 +0,0 @@
-RANDOM_SEED: 0
-# Don't need AMP to train a tiny linear layer.
-AMP: false
-CUDNN_BENCHMARK: true
-CUDNN_DETERMINISTIC: false
-
-DATA:
-  ROOT: "datasets/imagenet"
-  IMAGE_TRANSFORM_TRAIN:
-    - "random_resized_crop::{'scale': (0.08, 1.0)}"
-    - "horizontal_flip"
-    - "normalize"
-  IMAGE_TRANSFORM_VAL:
-    - "smallest_resize"
-    - "center_crop"
-    - "normalize"
-
-MODEL:
-  VISUAL:
-    FROZEN: true
-
-OPTIM:
-  BATCH_SIZE: 256
-  SGD_MOMENTUM: 0.9
-  WEIGHT_DECAY: 0.0
-  NO_DECAY: "none"
-  LOOKAHEAD:
-    USE: false
-
-  LR: 0.3
-  WARMUP_STEPS: 0
-  LR_DECAY_NAME: "cosine"
-  NUM_ITERATIONS: 500500  # 100 epochs
virtex/configs/downstream/inaturalist_clf.yaml DELETED
@@ -1,36 +0,0 @@
-RANDOM_SEED: 0
-AMP: true
-CUDNN_BENCHMARK: true
-CUDNN_DETERMINISTIC: false
-
-DATA:
-  ROOT: "datasets/inaturalist"
-  IMAGE_TRANSFORM_TRAIN:
-    - "random_resized_crop::{'scale': (0.08, 1.0)}"
-    - "horizontal_flip"
-    - "normalize"
-  IMAGE_TRANSFORM_VAL:
-    - "smallest_resize"
-    - "center_crop"
-    - "normalize"
-
-MODEL:
-  VISUAL:
-    FROZEN: false
-
-OPTIM:
-  BATCH_SIZE: 256
-  SGD_MOMENTUM: 0.9
-  WEIGHT_DECAY: 0.0001
-  NO_DECAY: "none"
-  LOOKAHEAD:
-    USE: false
-
-  LR: 0.025
-  WARMUP_STEPS: 0
-  LR_DECAY_NAME: multistep
-  LR_GAMMA: 0.1
-  LR_STEPS:
-    - 119700  # 70 epochs
-    - 153900  # 90 epochs
-  NUM_ITERATIONS: 171000  # 100 epochs
virtex/configs/downstream/voc07_clf.yaml DELETED
@@ -1,15 +0,0 @@
-RANDOM_SEED: 0
-DATA:
-  ROOT: datasets/VOC2007
-  IMAGE_TRANSFORM_TRAIN:
-    - smallest_resize
-    - center_crop
-    - normalize
-  IMAGE_TRANSFORM_VAL:
-    - smallest_resize
-    - center_crop
-    - normalize
-
-OPTIM:
-  # Only used for feature extraction, doesn't mean much.
-  BATCH_SIZE: 128
virtex/configs/redcaps/gcc_R_50_L6_H512.yaml DELETED
@@ -1,35 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-AMP: True
-
-DATA:
-  ROOT: "datasets/gcc/tarfiles/*.tar"
-  TOKENIZER_MODEL: "datasets/vocab/common_30k.model"
-  VOCAB_SIZE: 30000
-  UNK_INDEX: 0
-  SOS_INDEX: 1
-  EOS_INDEX: 2
-  MASK_INDEX: 3
-
-  MAX_CAPTION_LENGTH: 50
-
-MODEL:
-  NAME: "virtex_web"
-  TEXTUAL:
-    NAME: "transdec_prenorm::L6_H512_A8_F2048"
-
-  LABEL_SMOOTHING: 0.1
-
-OPTIM:
-  OPTIMIZER_NAME: "adamw"
-  WEIGHT_DECAY: 0.01
-  LOOKAHEAD:
-    USE: false
-
-  BATCH_SIZE: 256
-  CNN_LR: 0.0005
-  LR: 0.0005
-  NUM_ITERATIONS: 1500000
-
-  WARMUP_STEPS: 10000
-  LR_DECAY_NAME: "cosine"
virtex/configs/redcaps/miniclip_sbu_R_50_L12_H512.yaml DELETED
@@ -1,35 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-AMP: True
-
-DATA:
-  ROOT: "datasets/sbu/tarfiles/*.tar"
-  TOKENIZER_MODEL: "datasets/vocab/common_30k.model"
-  VOCAB_SIZE: 30000
-  UNK_INDEX: 0
-  SOS_INDEX: 1
-  EOS_INDEX: 2
-  MASK_INDEX: 3
-
-  MAX_CAPTION_LENGTH: 50
-
-MODEL:
-  NAME: "miniclip_web"
-  TEXTUAL:
-    NAME: "transenc_prenorm::L12_H512_A8_F2048"
-  LABEL_SMOOTHING: 0.1
-
-OPTIM:
-  OPTIMIZER_NAME: "adamw"
-  WEIGHT_DECAY: 0.01
-
-  LOOKAHEAD:
-    USE: false
-
-  BATCH_SIZE: 256
-  CNN_LR: 0.0005
-  LR: 0.0005
-  NUM_ITERATIONS: 1500000
-
-  WARMUP_STEPS: 10000
-  LR_DECAY_NAME: "cosine"
virtex/configs/redcaps/redcaps_2020_R_50_L6_H512.yaml DELETED
@@ -1,35 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-AMP: True
-
-DATA:
-  ROOT: "datasets/redcaps/tarfiles/*_2020_*.tar"
-  TOKENIZER_MODEL: "datasets/vocab/common_30k.model"
-  VOCAB_SIZE: 30000
-  UNK_INDEX: 0
-  SOS_INDEX: 1
-  EOS_INDEX: 2
-  MASK_INDEX: 3
-
-  MAX_CAPTION_LENGTH: 50
-
-MODEL:
-  NAME: "virtex_web"
-  TEXTUAL:
-    NAME: "transdec_prenorm::L6_H512_A8_F2048"
-  LABEL_SMOOTHING: 0.1
-
-OPTIM:
-  OPTIMIZER_NAME: "adamw"
-  WEIGHT_DECAY: 0.01
-
-  LOOKAHEAD:
-    USE: false
-
-  BATCH_SIZE: 256
-  CNN_LR: 0.0005
-  LR: 0.0005
-  NUM_ITERATIONS: 1500000
-
-  WARMUP_STEPS: 10000
-  LR_DECAY_NAME: "cosine"
virtex/configs/redcaps/redcaps_all_R_50_L6_H512.yaml DELETED
@@ -1,35 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-AMP: True
-
-DATA:
-  ROOT: "datasets/redcaps/tarfiles/*.tar"
-  TOKENIZER_MODEL: "datasets/vocab/common_30k.model"
-  VOCAB_SIZE: 30000
-  UNK_INDEX: 0
-  SOS_INDEX: 1
-  EOS_INDEX: 2
-  MASK_INDEX: 3
-
-  MAX_CAPTION_LENGTH: 50
-
-MODEL:
-  NAME: "virtex_web"
-  TEXTUAL:
-    NAME: "transdec_prenorm::L6_H512_A8_F2048"
-  LABEL_SMOOTHING: 0.1
-
-OPTIM:
-  OPTIMIZER_NAME: "adamw"
-  WEIGHT_DECAY: 0.01
-
-  LOOKAHEAD:
-    USE: false
-
-  BATCH_SIZE: 256
-  CNN_LR: 0.0005
-  LR: 0.0005
-  NUM_ITERATIONS: 1500000
-
-  WARMUP_STEPS: 10000
-  LR_DECAY_NAME: "cosine"
virtex/configs/redcaps/sbu_R_50_L6_H512.yaml DELETED
@@ -1,35 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-AMP: True
-
-DATA:
-  ROOT: "datasets/sbu/tarfiles/*.tar"
-  TOKENIZER_MODEL: "datasets/vocab/common_30k.model"
-  VOCAB_SIZE: 30000
-  UNK_INDEX: 0
-  SOS_INDEX: 1
-  EOS_INDEX: 2
-  MASK_INDEX: 3
-
-  MAX_CAPTION_LENGTH: 50
-
-MODEL:
-  NAME: "virtex_web"
-  TEXTUAL:
-    NAME: "transdec_prenorm::L6_H512_A8_F2048"
-  LABEL_SMOOTHING: 0.1
-
-OPTIM:
-  OPTIMIZER_NAME: "adamw"
-  WEIGHT_DECAY: 0.01
-
-  LOOKAHEAD:
-    USE: false
-
-  BATCH_SIZE: 256
-  CNN_LR: 0.0005
-  LR: 0.0005
-  NUM_ITERATIONS: 1500000
-
-  WARMUP_STEPS: 10000
-  LR_DECAY_NAME: "cosine"
virtex/configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  TEXTUAL:
-    NAME: "transdec_postnorm::L1_H2048_A32_F8192"
virtex/configs/task_ablations/captioning_R_50_L1_H2048.yaml DELETED
@@ -1,6 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  NAME: "captioning"
-  TEXTUAL:
-    NAME: "transdec_postnorm::L1_H2048_A32_F8192"
virtex/configs/task_ablations/masked_lm_R_50_L1_H2048.yaml DELETED
@@ -1,6 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  NAME: "masked_lm"
-  TEXTUAL:
-    NAME: "transdec_postnorm::L1_H2048_A32_F8192"
virtex/configs/task_ablations/multilabel_classification_R_50.yaml DELETED
@@ -1,12 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-DATA:
-  VOCAB_SIZE: 81
-
-MODEL:
-  NAME: "multilabel_classification"
-  TEXTUAL:
-    NAME: "none"
-
-OPTIM:
-  NO_DECAY: "none"
virtex/configs/task_ablations/token_classification_R_50.yaml DELETED
@@ -1,9 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  NAME: "token_classification"
-  TEXTUAL:
-    NAME: "none"
-
-OPTIM:
-  NO_DECAY: "none"
virtex/configs/width_ablations/bicaptioning_R_50_L1_H1024.yaml DELETED
@@ -1 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
virtex/configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  TEXTUAL:
-    NAME: "transdec_postnorm::L1_H2048_A32_F8192"
virtex/configs/width_ablations/bicaptioning_R_50_L1_H512.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  TEXTUAL:
-    NAME: "transdec_postnorm::L1_H512_A8_F2048"
virtex/configs/width_ablations/bicaptioning_R_50_L1_H768.yaml DELETED
@@ -1,5 +0,0 @@
-_BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
-
-MODEL:
-  TEXTUAL:
-    NAME: "transdec_postnorm::L1_H768_A12_F3072"
virtex/{virtex/data → data}/__init__.py RENAMED
File without changes
virtex/{virtex/data → data}/datasets/captioning.py RENAMED
File without changes
virtex/{virtex/data → data}/datasets/classification.py RENAMED
File without changes
virtex/{virtex/data → data}/datasets/downstream.py RENAMED
File without changes
virtex/{virtex/data → data}/datasets/masked_lm.py RENAMED
File without changes
virtex/{virtex/data → data}/datasets/redcaps.py RENAMED
File without changes
virtex/{virtex/data → data}/datasets/zero_shot.py RENAMED
File without changes
virtex/{virtex/data → data}/readers.py RENAMED
File without changes
virtex/{virtex/data → data}/tokenizers.py RENAMED
File without changes
virtex/{virtex/data → data}/transforms.py RENAMED
File without changes
virtex/docs/Makefile DELETED
@@ -1,19 +0,0 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line.
-SPHINXOPTS    =
-SPHINXBUILD   = sphinx-build
-SOURCEDIR     = .
-BUILDDIR      = ../../virtex-sphinx
-
-# Put it first so that "make" without argument is like "make help".
-help:
-	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
-.PHONY: help Makefile
-
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
-	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
virtex/docs/_static/custom.css DELETED
@@ -1,115 +0,0 @@
-body {
-    padding: 40px 0 0 0;
-    font-size: 12pt;
-    font-family: Inconsolata !important;
-}
-
-/* Monospace everywhere */
-h1, h2, h3, h4, div.sphinxsidebar h1, div.sphinxsidebar h2,
-div.sphinxsidebar h3, div.sphinxsidebar h4, div.body h1,
-div.body h2, div.body h3, div.body h4, .admonition-title {
-    font-family: monospace !important;
-}
-
-/* Make main content wider */
-div.document {
-    margin: auto;
-    width: 65%;
-}
-
-/* Make sidebar slightly wider. */
-div.sphinxsidebar {
-    width: 250px;
-}
-
-div.bodywrapper {
-    margin: 0 0 0 250px;
-}
-
-div.body {
-    color: black;
-    max-width: 100%
-}
-
-/* Darker headings */
-h1, h2, h3, h4, div.sphinxsidebar h1, div.sphinxsidebar h2,
-div.sphinxsidebar h3, div.sphinxsidebar h4, div.body h1,
-div.body h2, div.body h3, div.body h4 {
-    color: black;
-}
-
-@media screen and (max-width: 875px) {
-    div.sphinxsidebar {
-        background-color: white;
-    }
-}
-
-/* Darker bold words */
-strong {
-    color: #252525;
-}
-
-/* TOC tree tag, view source link & permalink anchor styling. */
-div.sphinxsidebar a, .viewcode-link, a.reference {
-    color: darkgreen;
-    text-decoration: none;
-    border-bottom: 1px dashed green;
-    text-underline-position: under;
-}
-a.headerlink {
-    color: black;
-}
-
-/* TOC tree tag, view source link & permalink anchor styling. */
-div.sphinxsidebar a:hover, .viewcode-link:hover, a.reference:hover,
-a.headerlink:hover {
-    font-weight: 700;
-    border-bottom: 1px solid green;
-}
-
-/* Add a light background to class signatures. */
-dl.class > dt:first-of-type, dl.function > dt:first-of-type,
-dl.method > dt:first-of-type, dl.classmethod > dt:first-of-type,
-dl.attribute > dt:first-of-type, dl.data > dt:first-of-type {
-    font-size: 14pt;
-    background-color: #d8f6e9;
-    padding: 10px 20px 10px 10px;
-    border: 1px solid #1b5e20;
-}
-
-/* Add lightgrey background to code snippets. */
-pre {
-    background-color: #eeeeee !important;
-    border: 1pt solid #999999;
-    border-radius: 5px;
-}
-
-/* Dark orange-red comments in code snippets. */
-.highlight .c1 {
-    color: #dd4533;
-}
-
-.admonition, .note {
-    background-color: #fed8b1 !important;
-    border: 1pt solid #ff7700;
-    border-radius: 5px;
-}
-
-/* Make "Parameters" subsection wider - display heading and content vertically. */
-dl.field-list {
-    display: block;
-}
-
-/* Increase font size of subsection headings ("Parameters", "Examples" etc.) */
-.rubric, dl.field-list > dt.field-odd, dl.field-list > dt.field-even {
-    color: black;
-    font-size: 18pt;
-    font-weight: bold;
-    padding: 0px;
-    margin: 20px 0px 20px 0px;
-}
-
-/* Add margins around methods and properties. */
-.py {
-    margin: 20px 0px 20px 0px;
-}