Vision-CAIR committed on
Commit 7a02f27
1 parent: dc2b1a4

Upload folder using huggingface_hub

__init__.py CHANGED
@@ -11,16 +11,22 @@ from omegaconf import OmegaConf
 
 from minigpt4_video.registry import registry
 from minigpt4_video.base_model import BaseModel
-from minigpt4_video.blip2 import Blip2Base
 from minigpt4_video.base_processor import BaseProcessor
+from minigpt4_video.blip_processors import *
+from minigpt4_video.blip2 import Blip2Base
+from minigpt4_video.clip_vision_encoder import *
+from minigpt4_video.config import *
+from minigpt4_video.eva_vit import *
 from minigpt4_video.mini_gpt4_llama_v2 import MiniGPT4_Video
 
 
+
 __all__ = [
     "load_model",
     "BaseModel",
     "Blip2Base",
     "MiniGPT4_Video",
+
 ]
 
 
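Note: with the wildcard imports added above, merely importing the package now registers the BLIP processors and vision encoders as a side effect. A minimal sketch of that effect, assuming the registry keeps the LAVIS-style get_processor_class helper:

import minigpt4_video  # the wildcard imports above run and populate the registry
from minigpt4_video.registry import registry

# "blip2_image_eval" is registered by the new blip_processors.py below.
proc_cls = registry.get_processor_class("blip2_image_eval")
proc = proc_cls.from_config(None)  # defaults: 224x224 bicubic resize + normalize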
__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/__pycache__/__init__.cpython-310.pyc and b/__pycache__/__init__.cpython-310.pyc differ
 
__pycache__/blip_processors.cpython-310.pyc ADDED
Binary file (4.36 kB).
 
__pycache__/clip_vision_encoder.cpython-310.pyc ADDED
Binary file (2.97 kB).
 
__pycache__/config.cpython-310.pyc ADDED
Binary file (12.3 kB).
 
__pycache__/mini_gpt4_llama_v2.cpython-310.pyc CHANGED
Binary files a/__pycache__/mini_gpt4_llama_v2.cpython-310.pyc and b/__pycache__/mini_gpt4_llama_v2.cpython-310.pyc differ
 
__pycache__/randaugment.cpython-310.pyc ADDED
Binary file (12.1 kB).
 
__pycache__/registry.cpython-310.pyc CHANGED
Binary files a/__pycache__/registry.cpython-310.pyc and b/__pycache__/registry.cpython-310.pyc differ
 
blip_processors.py ADDED
@@ -0,0 +1,164 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import re
+
+from minigpt4_video.registry import registry
+from minigpt4_video.base_processor import BaseProcessor
+from minigpt4_video.randaugment import RandomAugment
+from omegaconf import OmegaConf
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+
+class BlipImageBaseProcessor(BaseProcessor):
+    def __init__(self, mean=None, std=None):
+        if mean is None:
+            mean = (0.48145466, 0.4578275, 0.40821073)
+        if std is None:
+            std = (0.26862954, 0.26130258, 0.27577711)
+
+
+        segment_mean = (0.485, 0.456, 0.406)
+        segment_std = (0.229, 0.224, 0.225)
+
+        self.normalize = transforms.Normalize(segment_mean, segment_std)
+
+
+@registry.register_processor("blip_caption")
+class BlipCaptionProcessor(BaseProcessor):
+    def __init__(self, prompt="", max_words=50):
+        self.prompt = prompt
+        self.max_words = max_words
+
+    def __call__(self, caption):
+        caption = self.prompt + self.pre_caption(caption)
+
+        return caption
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        prompt = cfg.get("prompt", "")
+        max_words = cfg.get("max_words", 50)
+
+        return cls(prompt=prompt, max_words=max_words)
+
+    def pre_caption(self, caption):
+        caption = re.sub(
+            r"([.!\"()*#:;~])",
+            " ",
+            caption.lower(),
+        )
+        caption = re.sub(
+            r"\s{2,}",
+            " ",
+            caption,
+        )
+        caption = caption.rstrip("\n")
+        caption = caption.strip(" ")
+
+        # truncate caption
+        caption_words = caption.split(" ")
+        if len(caption_words) > self.max_words:
+            caption = " ".join(caption_words[: self.max_words])
+
+        return caption
+
+
+@registry.register_processor("blip2_image_train")
+class Blip2ImageTrainProcessor(BlipImageBaseProcessor):
+    def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
+        super().__init__(mean=mean, std=std)
+
+        # self.transform = transforms.Compose(
+        #     [
+        #         transforms.RandomResizedCrop(
+        #             image_size,
+        #             scale=(min_scale, max_scale),
+        #             interpolation=InterpolationMode.BICUBIC,
+        #         ),
+        #         transforms.ToTensor(),
+        #         self.normalize,
+        #     ]
+        # )
+        self.transform = transforms.Compose([
+            transforms.Resize(
+                (image_size, image_size), interpolation=InterpolationMode.BICUBIC
+            ),
+            transforms.ToTensor(),
+            self.normalize,
+        ]
+        )
+
+        # ### segment anything
+        # '''
+        # x = (x - self.pixel_mean) / self.pixel_std
+
+        # # Pad
+        # h, w = x.shape[-2:]
+        # padh = self.image_encoder.img_size - h
+        # padw = self.image_encoder.img_size - w
+        # x = F.pad(x, (0, padw, 0, padh))
+        # '''
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", 224)
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        min_scale = cfg.get("min_scale", 0.5)
+        max_scale = cfg.get("max_scale", 1.0)
+
+        return cls(
+            image_size=image_size,
+            mean=mean,
+            std=std,
+            min_scale=min_scale,
+            max_scale=max_scale,
+        )
+
+
+@registry.register_processor("blip2_image_eval")
+class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
+    def __init__(self, image_size=224, mean=None, std=None):
+        super().__init__(mean=mean, std=std)
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(
+                    (image_size, image_size), interpolation=InterpolationMode.BICUBIC
+                ),
+                transforms.ToTensor(),
+                self.normalize,
+            ]
+        )
+
+    def __call__(self, item):
+        return self.transform(item)
+
+    @classmethod
+    def from_config(cls, cfg=None):
+        if cfg is None:
+            cfg = OmegaConf.create()
+
+        image_size = cfg.get("image_size", 224)
+
+        mean = cfg.get("mean", None)
+        std = cfg.get("std", None)
+
+        return cls(image_size=image_size, mean=mean, std=std)
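Note: a quick usage sketch of the eval processor added above (the input image is illustrative). Also worth flagging: BlipImageBaseProcessor computes the CLIP mean/std defaults but then builds self.normalize from the hard-coded segment statistics, so the mean and std arguments are accepted yet effectively unused.

from PIL import Image
from minigpt4_video.blip_processors import Blip2ImageEvalProcessor

proc = Blip2ImageEvalProcessor(image_size=224)
img = Image.new("RGB", (640, 480))  # any RGB PIL image works here
out = proc(img)                     # torch.Tensor of shape (3, 224, 224)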
randaugment.py ADDED
@@ -0,0 +1,398 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import cv2
+import numpy as np
+
+import torch
+
+
+## aug functions
+def identity_func(img):
+    return img
+
+
+def autocontrast_func(img, cutoff=0):
+    """
+    same output as PIL.ImageOps.autocontrast
+    """
+    n_bins = 256
+
+    def tune_channel(ch):
+        n = ch.size
+        cut = cutoff * n // 100
+        if cut == 0:
+            high, low = ch.max(), ch.min()
+        else:
+            hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+            low = np.argwhere(np.cumsum(hist) > cut)
+            low = 0 if low.shape[0] == 0 else low[0]
+            high = np.argwhere(np.cumsum(hist[::-1]) > cut)
+            high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
+        if high <= low:
+            table = np.arange(n_bins)
+        else:
+            scale = (n_bins - 1) / (high - low)
+            offset = -low * scale
+            table = np.arange(n_bins) * scale + offset
+            table[table < 0] = 0
+            table[table > n_bins - 1] = n_bins - 1
+        table = table.clip(0, 255).astype(np.uint8)
+        return table[ch]
+
+    channels = [tune_channel(ch) for ch in cv2.split(img)]
+    out = cv2.merge(channels)
+    return out
+
+
+def equalize_func(img):
+    """
+    same output as PIL.ImageOps.equalize
+    PIL's implementation is different from cv2.equalize
+    """
+    n_bins = 256
+
+    def tune_channel(ch):
+        hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
+        non_zero_hist = hist[hist != 0].reshape(-1)
+        step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
+        if step == 0:
+            return ch
+        n = np.empty_like(hist)
+        n[0] = step // 2
+        n[1:] = hist[:-1]
+        table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
+        return table[ch]
+
+    channels = [tune_channel(ch) for ch in cv2.split(img)]
+    out = cv2.merge(channels)
+    return out
+
+
+def rotate_func(img, degree, fill=(0, 0, 0)):
+    """
+    like PIL, rotate by degree, not radians
+    """
+    H, W = img.shape[0], img.shape[1]
+    center = W / 2, H / 2
+    M = cv2.getRotationMatrix2D(center, degree, 1)
+    out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
+    return out
+
+
+def solarize_func(img, thresh=128):
+    """
+    same output as PIL.ImageOps.solarize
+    """
+    table = np.array([el if el < thresh else 255 - el for el in range(256)])
+    table = table.clip(0, 255).astype(np.uint8)
+    out = table[img]
+    return out
+
+
+def color_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Color
+    """
+    ## implementation according to PIL definition, quite slow
+    # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
+    # out = blend(degenerate, img, factor)
+    # M = (
+    #     np.eye(3) * factor
+    #     + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
+    # )[np.newaxis, np.newaxis, :]
+    M = np.float32(
+        [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
+    ) * factor + np.float32([[0.114], [0.587], [0.299]])
+    out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
+    return out
+
+
+def contrast_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Contrast
+    """
+    mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
+    table = (
+        np.array([(el - mean) * factor + mean for el in range(256)])
+        .clip(0, 255)
+        .astype(np.uint8)
+    )
+    out = table[img]
+    return out
+
+
+def brightness_func(img, factor):
+    """
+    same output as PIL.ImageEnhance.Brightness
+    """
+    table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
+    out = table[img]
+    return out
+
+
+def sharpness_func(img, factor):
+    """
+    The differences between this result and PIL's are all on the 4 boundaries; the
+    center areas are the same
+    """
+    kernel = np.ones((3, 3), dtype=np.float32)
+    kernel[1][1] = 5
+    kernel /= 13
+    degenerate = cv2.filter2D(img, -1, kernel)
+    if factor == 0.0:
+        out = degenerate
+    elif factor == 1.0:
+        out = img
+    else:
+        out = img.astype(np.float32)
+        degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
+        out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
+        out = out.astype(np.uint8)
+    return out
+
+
+def shear_x_func(img, factor, fill=(0, 0, 0)):
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, factor, 0], [0, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def translate_x_func(img, offset, fill=(0, 0, 0)):
+    """
+    same output as PIL.Image.transform
+    """
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, -offset], [0, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def translate_y_func(img, offset, fill=(0, 0, 0)):
+    """
+    same output as PIL.Image.transform
+    """
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, 0], [0, 1, -offset]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def posterize_func(img, bits):
+    """
+    same output as PIL.ImageOps.posterize
+    """
+    out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
+    return out
+
+
+def shear_y_func(img, factor, fill=(0, 0, 0)):
+    H, W = img.shape[0], img.shape[1]
+    M = np.float32([[1, 0, 0], [factor, 1, 0]])
+    out = cv2.warpAffine(
+        img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
+    ).astype(np.uint8)
+    return out
+
+
+def cutout_func(img, pad_size, replace=(0, 0, 0)):
+    replace = np.array(replace, dtype=np.uint8)
+    H, W = img.shape[0], img.shape[1]
+    rh, rw = np.random.random(2)
+    pad_size = pad_size // 2
+    ch, cw = int(rh * H), int(rw * W)
+    x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
+    y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
+    out = img.copy()
+    out[x1:x2, y1:y2, :] = replace
+    return out
+
+
+### level to args
+def enhance_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        return ((level / MAX_LEVEL) * 1.8 + 0.1,)
+
+    return level_to_args
+
+
+def shear_level_to_args(MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * 0.3
+        if np.random.random() > 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * float(translate_const)
+        if np.random.random() > 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * cutout_const)
+        return (level, replace_value)
+
+    return level_to_args
+
+
+def solarize_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * 256)
+        return (level,)
+
+    return level_to_args
+
+
+def none_level_to_args(level):
+    return ()
+
+
+def posterize_level_to_args(MAX_LEVEL):
+    def level_to_args(level):
+        level = int((level / MAX_LEVEL) * 4)
+        return (level,)
+
+    return level_to_args
+
+
+def rotate_level_to_args(MAX_LEVEL, replace_value):
+    def level_to_args(level):
+        level = (level / MAX_LEVEL) * 30
+        if np.random.random() < 0.5:
+            level = -level
+        return (level, replace_value)
+
+    return level_to_args
+
+
+func_dict = {
+    "Identity": identity_func,
+    "AutoContrast": autocontrast_func,
+    "Equalize": equalize_func,
+    "Rotate": rotate_func,
+    "Solarize": solarize_func,
+    "Color": color_func,
+    "Contrast": contrast_func,
+    "Brightness": brightness_func,
+    "Sharpness": sharpness_func,
+    "ShearX": shear_x_func,
+    "TranslateX": translate_x_func,
+    "TranslateY": translate_y_func,
+    "Posterize": posterize_func,
+    "ShearY": shear_y_func,
+}
+
+translate_const = 10
+MAX_LEVEL = 10
+replace_value = (128, 128, 128)
+arg_dict = {
+    "Identity": none_level_to_args,
+    "AutoContrast": none_level_to_args,
+    "Equalize": none_level_to_args,
+    "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
+    "Solarize": solarize_level_to_args(MAX_LEVEL),
+    "Color": enhance_level_to_args(MAX_LEVEL),
+    "Contrast": enhance_level_to_args(MAX_LEVEL),
+    "Brightness": enhance_level_to_args(MAX_LEVEL),
+    "Sharpness": enhance_level_to_args(MAX_LEVEL),
+    "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
+    "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+    "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
+    "Posterize": posterize_level_to_args(MAX_LEVEL),
+    "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
+}
+
+
+class RandomAugment(object):
+    def __init__(self, N=2, M=10, isPIL=False, augs=[]):
+        self.N = N
+        self.M = M
+        self.isPIL = isPIL
+        if augs:
+            self.augs = augs
+        else:
+            self.augs = list(arg_dict.keys())
+
+    def get_random_ops(self):
+        sampled_ops = np.random.choice(self.augs, self.N)
+        return [(op, 0.5, self.M) for op in sampled_ops]
+
+    def __call__(self, img):
+        if self.isPIL:
+            img = np.array(img)
+        ops = self.get_random_ops()
+        for name, prob, level in ops:
+            if np.random.random() > prob:
+                continue
+            args = arg_dict[name](level)
+            img = func_dict[name](img, *args)
+        return img
+
+
+class VideoRandomAugment(object):
+    def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
+        self.N = N
+        self.M = M
+        self.p = p
+        self.tensor_in_tensor_out = tensor_in_tensor_out
+        if augs:
+            self.augs = augs
+        else:
+            self.augs = list(arg_dict.keys())
+
+    def get_random_ops(self):
+        sampled_ops = np.random.choice(self.augs, self.N, replace=False)
+        return [(op, self.M) for op in sampled_ops]
+
+    def __call__(self, frames):
+        assert (
+            frames.shape[-1] == 3
+        ), "Expecting last dimension for 3-channels RGB (b, h, w, c)."
+
+        if self.tensor_in_tensor_out:
+            frames = frames.numpy().astype(np.uint8)
+
+        num_frames = frames.shape[0]
+
+        ops = num_frames * [self.get_random_ops()]
+        apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
+
+        frames = torch.stack(
+            list(map(self._aug, frames, ops, apply_or_not)), dim=0
+        ).float()
+
+        return frames
+
+    def _aug(self, img, ops, apply_or_not):
+        for i, (name, level) in enumerate(ops):
+            if not apply_or_not[i]:
+                continue
+            args = arg_dict[name](level)
+            img = func_dict[name](img, *args)
+        return torch.from_numpy(img)
+
+
+if __name__ == "__main__":
+    a = RandomAugment()
+    img = np.random.randn(32, 32, 3)
+    a(img)
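Note: a usage sketch for the cv2-based RandomAugment above. Inputs should be uint8 HWC arrays, since several ops index 256-entry lookup tables; the float array in the __main__ smoke test only works when the sampled ops happen to avoid those tables.

import numpy as np
from minigpt4_video.randaugment import RandomAugment

img = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)
aug = RandomAugment(N=2, M=5, isPIL=False,
                    augs=["Identity", "Brightness", "ShearX"])
out = aug(img)  # each of the N sampled ops fires with probability 0.5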
registry.py CHANGED
@@ -98,12 +98,12 @@ class Registry:
             # model_cls, BaseModel
             # ), "All models must inherit BaseModel class"
 
-            if name in cls.mapping["model_name_mapping"]:
-                raise KeyError(
-                    "Name '{}' already registered for {}.".format(
-                        name, cls.mapping["model_name_mapping"][name]
-                    )
-                )
+            # if name in cls.mapping["model_name_mapping"]:
+            #     raise KeyError(
+            #         "Name '{}' already registered for {}.".format(
+            #             name, cls.mapping["model_name_mapping"][name]
+            #         )
+            #     )
             cls.mapping["model_name_mapping"][name] = model_cls
             return model_cls
 
@@ -124,15 +124,15 @@ class Registry:
         def wrap(processor_cls):
             from minigpt4.processors import BaseProcessor
 
-            assert issubclass(
-                processor_cls, BaseProcessor
-            ), "All processors must inherit BaseProcessor class"
-            if name in cls.mapping["processor_name_mapping"]:
-                raise KeyError(
-                    "Name '{}' already registered for {}.".format(
-                        name, cls.mapping["processor_name_mapping"][name]
-                    )
-                )
+            # assert issubclass(
+            #     processor_cls, BaseProcessor
+            # ), "All processors must inherit BaseProcessor class"
+            # if name in cls.mapping["processor_name_mapping"]:
+            #     raise KeyError(
+            #         "Name '{}' already registered for {}.".format(
+            #             name, cls.mapping["processor_name_mapping"][name]
+            #         )
+            #     )
             cls.mapping["processor_name_mapping"][name] = processor_cls
             return processor_cls
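Note: both hunks relax the registry. Processors no longer have to subclass BaseProcessor, and re-registering an existing name now silently overwrites the earlier class instead of raising KeyError, which tolerates repeated imports but also lets a typo shadow a real registration. (The wrap body still imports BaseProcessor from minigpt4.processors, so that package must stay importable even though the assert is gone.) A sketch of the new behavior, using a hypothetical name:

from minigpt4_video.registry import registry

@registry.register_processor("demo_proc")  # hypothetical name
class First:
    pass

@registry.register_processor("demo_proc")  # previously raised KeyError
class Second:                               # now silently replaces First
    pass

assert registry.mapping["processor_name_mapping"]["demo_proc"] is Second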