# Copyright (c) Tencent Inc. All rights reserved.
import itertools
from typing import List, Optional, Sequence, Tuple
import torch
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm
from mmengine.model import BaseModule
from mmyolo.registry import MODELS
from mmdet.utils import OptMultiConfig, ConfigType
from transformers import AutoModel, AutoTokenizer, CLIPTextConfig
from transformers import CLIPTextModelWithProjection as CLIPTP


@MODELS.register_module()
class HuggingVisionBackbone(BaseModule):
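    """Vision backbone that wraps an image model from Hugging Face.

    Args:
        model_name (str): model name or local path passed to
            ``AutoModel.from_pretrained``.
        out_indices (Sequence[int]): indices of the hidden states to
            return as multi-scale image features.
        norm_eval (bool): whether to keep BatchNorm layers in eval mode
            during training.
        frozen_modules (Sequence[str]): prefixes of submodules to freeze.
        init_cfg (OptMultiConfig): initialization config.
    """
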
    def __init__(self,
                 model_name: str,
                 out_indices: Sequence[int] = (0, 1, 2, 3),
                 norm_eval: bool = True,
                 frozen_modules: Sequence[str] = (),
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.norm_eval = norm_eval
        self.frozen_modules = frozen_modules
        self.out_indices = out_indices
        self.model = AutoModel.from_pretrained(model_name)
        self._freeze_modules()

def forward(self, image: Tensor) -> Tuple[Tensor]:
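        """Return the hidden states selected by ``out_indices`` as a
        tuple of multi-scale image features."""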
        encoded_dict = self.model(pixel_values=image,
                                  output_hidden_states=True)
        hidden_states = encoded_dict.hidden_states
        # Prefer spatially reshaped hidden states when the model provides
        # them (e.g. Swin-style backbones); fall back to the raw ones.
        img_feats = encoded_dict.get('reshaped_hidden_states', hidden_states)
        img_feats = [img_feats[i] for i in self.out_indices]
        return tuple(img_feats)

def _freeze_modules(self):
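        """Put the listed submodules in eval mode and disable their grads."""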
for name, module in self.model.named_modules():
for frozen_name in self.frozen_modules:
if name.startswith(frozen_name):
module.eval()
for param in module.parameters():
param.requires_grad = False
break

    def train(self, mode=True):
        super().train(mode)
        self._freeze_modules()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval() only has an effect on BatchNorm layers
                if isinstance(m, _BatchNorm):
                    m.eval()


@MODELS.register_module()
class HuggingCLIPLanguageBackbone(BaseModule):
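    """CLIP text encoder backbone built on Hugging Face Transformers.

    Args:
        model_name (str): model name or local path passed to
            ``AutoTokenizer``/``CLIPTextModelWithProjection``, e.g.
            ``openai/clip-vit-base-patch32``.
        frozen_modules (Sequence[str]): prefixes of submodules to freeze;
            ``("all",)`` freezes the entire text model.
        dropout (float): attention dropout for the text encoder.
        training_use_cache (bool): whether to cache text features during
            training (currently unused in ``forward``).
        init_cfg (OptMultiConfig): initialization config.
    """
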
def __init__(self,
model_name: str,
frozen_modules: Sequence[str] = (),
dropout: float = 0.0,
training_use_cache: bool = False,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg)
self.frozen_modules = frozen_modules
self.training_use_cache = training_use_cache
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        clip_config = CLIPTextConfig.from_pretrained(
            model_name, attention_dropout=dropout)
        self.model = CLIPTP.from_pretrained(model_name, config=clip_config)
        self._freeze_modules()

def forward_cache(self, text: List[List[str]]) -> Tensor:
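        """Compute the text features once and reuse the cached tensor."""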
if not hasattr(self, "cache"):
self.cache = self.forward_text(text)
return self.cache

    def forward(self, text: List[List[str]]) -> Tensor:
        # Caching via forward_cache is currently disabled; text features
        # are recomputed on every call, in training and inference alike.
        return self.forward_text(text)

    def forward_tokenizer(self, texts):
        # Tokenize once and cache the result; assumes the prompt set does
        # not change between calls.
        if not hasattr(self, 'text'):
            text = list(itertools.chain(*texts))
            text = self.tokenizer(text=text, return_tensors='pt',
                                  padding=True)
            self.text = text.to(device=self.model.device)
        return self.text

def forward_text(self, text: List[List[str]]) -> Tensor:
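        """Encode the per-image prompt lists into L2-normalized CLIP text
        embeddings of shape (batch_size, num_prompts, embed_dim)."""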
        num_per_batch = [len(t) for t in text]
        assert max(num_per_batch) == min(num_per_batch), (
            'number of sequences must be equal for each image in the batch')
        # flatten the per-image prompt lists into one list of strings
        text = list(itertools.chain(*text))
        text = self.tokenizer(text=text, return_tensors='pt', padding=True)
        text = text.to(device=self.model.device)
        txt_outputs = self.model(**text)
        # use the projected text embeddings and L2-normalize them
        txt_feats = txt_outputs.text_embeds
        txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
        txt_feats = txt_feats.reshape(-1, num_per_batch[0],
                                      txt_feats.shape[-1])
        return txt_feats

def _freeze_modules(self):
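        """Freeze submodules by name prefix; ``"all"`` freezes the model."""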
        if len(self.frozen_modules) == 0:
            # nothing to freeze
            return
if self.frozen_modules[0] == "all":
self.model.eval()
for _, module in self.model.named_modules():
module.eval()
for param in module.parameters():
param.requires_grad = False
return
for name, module in self.model.named_modules():
for frozen_name in self.frozen_modules:
if name.startswith(frozen_name):
module.eval()
for param in module.parameters():
param.requires_grad = False
break

    def train(self, mode=True):
super().train(mode)
self._freeze_modules()


@MODELS.register_module()
class PseudoLanguageBackbone(BaseModule):
"""Pseudo Language Backbone
Args:
text_embed_path (str): path to the text embedding file
"""
    def __init__(self,
                 text_embed_path: str = "",
                 test_embed_path: Optional[str] = None,
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        # {text: embedding} dict loaded onto CPU
        self.text_embed = torch.load(text_embed_path, map_location='cpu')
        if test_embed_path is None:
            self.test_embed = self.text_embed
        else:
            self.test_embed = torch.load(test_embed_path,
                                         map_location='cpu')
        # dummy buffer used only to track the module's device
        self.register_buffer("buff", torch.zeros([1]))

    def forward_cache(self, text: List[List[str]]) -> Tensor:
if not hasattr(self, "cache"):
self.cache = self.forward_text(text)
return self.cache

    def forward(self, text: List[List[str]]) -> Tensor:
        if self.training:
            return self.forward_text(text)
        else:
            return self.forward_cache(text)

def forward_text(self, text: List[List[str]]) -> Tensor:
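        """Look up the stored embedding for each prompt and reshape to
        (batch_size, num_prompts, embed_dim)."""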
num_per_batch = [len(t) for t in text]
assert max(num_per_batch) == min(num_per_batch), (
'number of sequences not equal in batch')
text = list(itertools.chain(*text))
if self.training:
text_embed_dict = self.text_embed
else:
text_embed_dict = self.test_embed
        text_embeds = torch.stack(
            [text_embed_dict[x.split("/")[0]] for x in text])
        # no gradients are needed; force float32 on the module's device
        text_embeds = text_embeds.to(
            self.buff.device).requires_grad_(False).float()
text_embeds = text_embeds.reshape(-1, num_per_batch[0],
text_embeds.shape[-1])
return text_embeds


@MODELS.register_module()
class MultiModalYOLOBackbone(BaseModule):
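    """Two-stream backbone pairing an image model with a text model.

    Both sub-models are built through the MODELS registry, so any
    registered image or text backbone config can be plugged in.
    """
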
def __init__(self,
image_model: ConfigType,
text_model: ConfigType,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg)
self.image_model = MODELS.build(image_model)
self.text_model = MODELS.build(text_model)

    def forward(self, image: Tensor,
text: List[List[str]]) -> Tuple[Tuple[Tensor], Tensor]:
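        """Extract features in one pass, returning a tuple of multi-scale
        image features and the text feature tensor."""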
img_feats = self.image_model(image)
txt_feats = self.text_model(text)
return img_feats, txt_feats
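
# A minimal usage sketch (illustrative only: the sub-model configs below are
# assumptions, not configs shipped with this file):
#
#   backbone = MODELS.build(
#       dict(type='MultiModalYOLOBackbone',
#            image_model=dict(type='HuggingVisionBackbone',
#                             model_name='microsoft/resnet-50'),
#            text_model=dict(type='HuggingCLIPLanguageBackbone',
#                            model_name='openai/clip-vit-base-patch32')))
#   img_feats, txt_feats = backbone(images, [['person', 'dog', 'car']])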