Vision-CAIR committed
Commit
9942389
1 Parent(s): 5a5f255

Upload folder using huggingface_hub

Files changed (3)
  1. mini_gpt4_llama_v2.py +2 -1
  2. registry.py +1 -0
  3. utils.py +1 -1
mini_gpt4_llama_v2.py CHANGED
@@ -102,7 +102,8 @@ class MiniGPT4_Video(Blip2Base, PreTrainedModel):
         Blip2Base.__init__(self)
 
         vis_processor_cfg = {"name": "blip2_image_train","image_size": 224}
-        self.vis_processor = registry.get_processor_class(vis_processor_cfg["name"]).from_config(vis_processor_cfg)
+        self.vis_processor = registry.get_processor_class(vis_processor_cfg["name"])
+        self.vis_processor = self.vis_processor.from_config(vis_processor_cfg)
         self.CONV_VISION = CONV_VISION
         if "Mistral" in self.llama_model:
             from .modeling_mistral import MistralForCausalLM as llm_model
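
The change above splits a chained call into two assignments, so the class returned by registry.get_processor_class can be inspected before from_config runs; get_processor_class returns None for unregistered names, which the chained form would only surface as an AttributeError on NoneType. Below is a minimal sketch of that lookup, assuming the package-level registry import from utils.py further down; the None guard is illustrative and not part of the commit.

from minigpt4_video.registry import registry

vis_processor_cfg = {"name": "blip2_image_train", "image_size": 224}

# Resolve the processor class first; get_processor_class returns None for unknown names.
processor_cls = registry.get_processor_class(vis_processor_cfg["name"])
if processor_cls is None:
    # Hypothetical guard, not part of the commit: fail with a readable message
    # instead of "'NoneType' object has no attribute 'from_config'".
    raise KeyError(f"no processor registered under {vis_processor_cfg['name']!r}")

# Then build the processor from its config dict, as the model's __init__ does.
vis_processor = processor_cls.from_config(vis_processor_cfg)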
registry.py CHANGED
@@ -243,6 +243,7 @@ class Registry:
 
     @classmethod
     def get_processor_class(cls, name):
+        print(cls.mapping["processor_name_mapping"])
         return cls.mapping["processor_name_mapping"].get(name, None)
 
     @classmethod
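
For context, the mapping printed here is the registry's name-to-class table for processors. The following is a minimal, self-contained sketch of that pattern; register_processor and the toy processor class are assumed for illustration and are not taken from this commit.

class Registry:
    mapping = {"processor_name_mapping": {}}

    @classmethod
    def register_processor(cls, name):
        # Decorator that records a processor class under a string name.
        def wrap(processor_cls):
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls
        return wrap

    @classmethod
    def get_processor_class(cls, name):
        # This is the dict the commit's added print() dumps: every processor
        # that has actually been imported and registered so far.
        return cls.mapping["processor_name_mapping"].get(name, None)


@Registry.register_processor("blip2_image_train")
class ToyProcessor:
    @classmethod
    def from_config(cls, cfg):
        return cls()


print(Registry.get_processor_class("blip2_image_train"))  # -> <class '__main__.ToyProcessor'>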
utils.py CHANGED
@@ -23,7 +23,7 @@ import pandas as pd
 import yaml
 from iopath.common.download import download
 from iopath.common.file_io import file_lock, g_pathmgr
-from .registry import registry
+from minigpt4_video.registry import registry
 from torch.utils.model_zoo import tqdm
 from torchvision.datasets.utils import (
     check_integrity,
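
The switch from a relative to an absolute import presumably matters when utils.py is loaded outside its package (for example, when the files are fetched and imported dynamically), since a relative import then has no parent package to resolve against; that reasoning is an assumption, not stated in the commit. A small sketch of the distinction, using a fallback idiom for demonstration only, whereas the commit simply standardizes on the absolute form:

try:
    # Only valid when utils.py is imported as a submodule of its package,
    # i.e. when __package__ is set for the module.
    from .registry import registry
except ImportError:
    # Absolute form used by the commit: resolves the same module even when
    # the file is loaded outside the package (e.g. executed directly).
    from minigpt4_video.registry import registry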