designvlm / build_mlp.py
Yosemat's picture
Upload folder using huggingface_hub
8eeaa19 verified
import math
import re
import torch
import torch.nn as nn
from transformers import CLIPVisionModel
def build_vision_tower():
vision_tower = "internlm/internlm-xcomposer2d5-clip"
return CLIPVisionTower(vision_tower)
def build_vision_projector():
projector_type = "mlp2x_gelu"
mm_hidden_size = 4096
mid_hidden_size = 4096
hidden_size = 4096
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(mm_hidden_size, mid_hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(mid_hidden_size, mid_hidden_size))
return nn.Sequential(*modules)
if projector_type == "identity":
return IdentityMap()
raise ValueError(f"Unknown projector type: {projector_type}")
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": "identity"}
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
# self.conv_dim = 8192
# self.conv = torch.nn.Conv2d(1024, self.conv_dim,3,2,1)
self.select_layer = -1
self.select_feature = "patch"
self.load_model()
def load_model(self):
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def resize_pos(self):
print("Dummy Resized")
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == "patch":
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
def forward(self, images, glb_GN, sub_GN) -> tuple[torch.Tensor, list[int]]:
if not self.is_loaded:
self.load_model()
assert type(images) is list
shapes = []
input_imgs = []
for img in images:
_, C, H, W = img.shape
shapes.append([H // 560, W // 560])
sub_img = (
img.reshape(1, 3, H // 560, 560, W // 560, 560)
.permute(0, 2, 4, 1, 3, 5)
.reshape(-1, 3, 560, 560)
.contiguous()
)
glb_img = torch.nn.functional.interpolate(
img.float(),
size=(560, 560),
mode="bicubic",
).to(sub_img.dtype)
input_imgs.append(glb_img)
input_imgs.append(sub_img)
input_imgs = torch.cat(input_imgs, dim=0)
image_forward_outs = self.vision_tower(
input_imgs.to(device=self.device, dtype=self.dtype),
output_hidden_states=True,
)
image_features = self.feature_select(image_forward_outs).to(
input_imgs.dtype
) ### B*?, N, C
_, N, C = image_features.shape
H = int(math.sqrt(N))
assert N == 40**2
output_imgs = []
output_len = []
for [h, w] in shapes:
B_ = h * w
glb_img = image_features[:1] ### 1, N, C
glb_img = (
glb_img.reshape(1, H, H, C)
.reshape(1, H // 2, 2, H // 2, 2, C)
.contiguous()
.permute(0, 1, 3, 2, 4, 5)
.reshape(1, H // 2, H // 2, 4 * C)
.contiguous()
)
temp_glb_GN = sub_GN.repeat(1, H // 2, 1, 1)
glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(1, -1, 4 * C)
sub_img = image_features[1 : 1 + B_] ### ?, N, C
sub_img = (
sub_img.reshape(B_, H, H, C)
.reshape(B_, H // 2, 2, H // 2, 2, C)
.contiguous()
.permute(0, 1, 3, 2, 4, 5)
.reshape(B_, -1, 4 * C)
.contiguous()
)
sub_img = (
sub_img.reshape(1, h, w, 20, 20, -1)
.permute(0, 1, 3, 2, 4, 5)
.reshape(1, h * 20, w * 20, 4 * C)
)
temp_sub_GN = sub_GN.repeat(1, h * 20, 1, 1)
sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(1, -1, 4 * C)
output_imgs.append(torch.cat([glb_img, glb_GN, sub_img], dim=1))
temp_len = int((h * w + 1) * 400 + 1 + (h + 1) * 20)
assert temp_len == output_imgs[-1].shape[1]
output_len.append(temp_len)
image_features = image_features[1 + h * w :]
output_imgs = torch.cat(output_imgs, dim=1)
return output_imgs, output_len
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
class PLoRA(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
lora_r=8,
lora_alpha=16,
lora_dropout=0.05,
lora_len=0,
**kwargs,
) -> None:
super().__init__(in_features, out_features, bias, device, dtype)
self.lora_r = lora_r
self.lora_alpha = lora_alpha
self.lora_len = lora_len
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
self.lora_scaling = self.lora_alpha / self.lora_r
self.Plora_A = nn.Linear(
in_features, self.lora_r, bias=False, device=device, dtype=dtype
)
self.Plora_B = nn.Linear(
self.lora_r, out_features, bias=False, device=device, dtype=dtype
)
self.lora_sft_A = nn.Linear(
in_features, 256, bias=False, device=device, dtype=dtype
)
self.lora_sft_B = nn.Linear(
256, out_features, bias=False, device=device, dtype=dtype
)
self.lora_dpo_A = nn.Linear(
in_features, 256, bias=False, device=device, dtype=dtype
)
self.lora_dpo_B = nn.Linear(
256, out_features, bias=False, device=device, dtype=dtype
)
self.lora_web_A = nn.Linear(
in_features, 512, bias=False, device=device, dtype=dtype
)
self.lora_web_B = nn.Linear(
512, out_features, bias=False, device=device, dtype=dtype
)
self.reset_parameters()
def reset_parameters(self):
if hasattr(self, "lora_A"):
# initialize A the same way as the default for nn.Linear and B to zero
nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_B.weight)
# print ("lora weight init {} {}".format(torch.mean(self.lora_A.weight), torch.mean(self.lora_B.weight)))
def forward(self, x, im_mask=None, infer_mode="base"):
B, N, C = x.shape
im_mask = im_mask.view(-1)
x = x.reshape(-1, C)
res = super().forward(x)
if infer_mode == "web":
res += self.lora_web_B(self.lora_web_A(x))
elif infer_mode == "write":
res += self.lora_sft_B(self.lora_sft_A(x))
res += self.lora_dpo_B(self.lora_dpo_A(x))
else:
pass
if im_mask is not None:
if torch.sum(im_mask) > 0:
part_x = x[im_mask]
res[im_mask] += (
self.Plora_B(self.Plora_A(self.lora_dropout(part_x)))
* self.lora_scaling
)
else:
part_x = x[:1]
res[:1] += self.Plora_B(self.Plora_A(self.lora_dropout(part_x))) * 0
return res.reshape(B, N, -1)