import torch
import torch.nn as nn
from collections import OrderedDict

from mmengine.model import BaseModel

from xtuner.registry import BUILDER
from xtuner.model.utils import get_peft_model_state_dict


class LisaModel(BaseModel):

    def __init__(self,
                 mllm,
                 tokenizer,
                 grounding_encoder,
                 loss_mask=None,
                 loss_dice=None):
        super(LisaModel, self).__init__()
        self.mllm = BUILDER.build(mllm)
        if self.mllm.use_llm_lora:
            # Keep lm_head and embed_tokens trainable under LoRA so the newly
            # added [SEG] token can be learned.
            self.mllm.model.language_model.base_model.model.lm_head.requires_grad_(True)
            self.mllm.model.language_model.base_model.model.model.embed_tokens.requires_grad_(True)
        self.tokenizer = BUILDER.build(tokenizer)
        self._add_special_tokens()

        # Freeze the SAM-style grounding encoder except for its mask decoder.
        self.grounding_encoder = BUILDER.build(grounding_encoder)
        self.grounding_encoder.requires_grad_(False)
        self.grounding_encoder.mask_decoder.requires_grad_(True)

        # Projects the LLM hidden states of [SEG] tokens into prompt
        # embeddings for the mask decoder.
        in_dim = self.mllm.model.config.llm_config.hidden_size
        out_dim = self.grounding_encoder.mask_decoder.transformer_dim
        self.text_hidden_fcs = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim), nn.Dropout(0.0))

        self.loss_mask = BUILDER.build(loss_mask)
        self.loss_dice = BUILDER.build(loss_dice)

    def _add_special_tokens(self):
        special_tokens = ['[SEG]']
        num_new_tokens = self.tokenizer.add_tokens(
            special_tokens, special_tokens=True)
        if num_new_tokens > 0:
            self.mllm.model.language_model.resize_token_embeddings(
                len(self.tokenizer))
        self.seg_token_idx = self.tokenizer(
            '[SEG]', add_special_tokens=False).input_ids[0]

    def _generate_and_postprocess_masks(self, pred_embeddings,
                                        image_embeddings, resize_list=None,
                                        orig_size_list=None):
        pred_masks = []
        for i, pred_embedding in enumerate(pred_embeddings):
            # Each [SEG] embedding acts as a text prompt for the mask decoder.
            sparse_embeddings, dense_embeddings = self.grounding_encoder.prompt_encoder(
                points=None,
                boxes=None,
                masks=None,
                text_embeds=pred_embedding.unsqueeze(1))
            sparse_embeddings = sparse_embeddings.to(pred_embedding.dtype)
            low_res_masks, _ = self.grounding_encoder.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.grounding_encoder.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False)
            pred_mask = self.grounding_encoder.postprocess_masks(
                low_res_masks,
                input_size=resize_list[i],
                original_size=orig_size_list[i])
            pred_masks.append(pred_mask[:, 0])
        return pred_masks

    def load_state_dict(self, state_dict, strict: bool = True,
                        assign: bool = False):
        return super().load_state_dict(state_dict, strict, assign)

    def state_dict(self, *args, **kwargs):
        state_dict = super().state_dict(*args, **kwargs)
        to_return = OrderedDict()
        # Step 1. visual_encoder
        if self.mllm.use_visual_encoder_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.mllm.model.vision_model, state_dict=state_dict))
        elif not self.mllm.freeze_visual_encoder:
            to_return.update({
                k: v
                for k, v in state_dict.items() if 'visual_encoder.' in k
            })
        # Step 2. LLM
        if self.mllm.use_llm_lora:
            to_return.update(
                get_peft_model_state_dict(
                    self.mllm.model.language_model, state_dict=state_dict))
        elif not self.mllm.freeze_llm:
            to_return.update(
                {k: v for k, v in state_dict.items() if 'llm.' in k})
        # Step 3. Projector
        to_return.update(
            {k: v for k, v in state_dict.items() if 'mlp1.' in k})
        to_return.update({
            k: v
            for k, v in state_dict.items()
            if 'grounding_encoder.mask_decoder.' in k
        })
        to_return.update({
            k: v
            for k, v in state_dict.items() if 'text_hidden_fcs.' in k
        })
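        # Step 4. lm_head & embed_tokens rows extended for the new [SEG] token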
        to_return.update(
            {k: v for k, v in state_dict.items() if 'lm_head.weight' in k})
        to_return.update(
            {k: v for k, v in state_dict.items() if 'embed_tokens.weight' in k})
        return to_return

    def forward(self, data, data_samples=None, mode='loss'):
        if mode == 'loss':
            return self.compute_loss(data)
        elif mode == 'predict':
            return self.predict(data)
        elif mode == 'tensor':
            return self._forward(data)
        else:
            raise NotImplementedError

    def compute_loss(self, data, data_samples=None, mode='loss'):
        g_pixel_values = data.pop('g_pixel_values', None)
        gt_masks = data.pop('masks', None)
        input_ids = data['input_ids']
        output = self.mllm(data, data_samples, mode)

        if gt_masks is None:
            # No segmentation targets in this batch: feed dummy pixels and a
            # fake [SEG] position so the segmentation branch still runs and
            # contributes a zero-valued loss.
            g_pixel_values = [
                torch.randn(3, 512, 1024).to(output.hidden_states[-1])
                for _ in range(len(input_ids))]
            ori_size_list = [(512, 1024) for _ in range(len(input_ids))]
            seg_token_mask = torch.zeros_like(input_ids).bool()
            seg_token_mask[:, -2] = True
        else:
            ori_size_list = [mask.shape[-2:] for mask in gt_masks]
            seg_token_mask = input_ids == self.seg_token_idx

        resize_list = [pixel.shape[-2:] for pixel in g_pixel_values]
        g_pixel_values = torch.stack([
            self.grounding_encoder.preprocess(pixel)
            for pixel in g_pixel_values
        ])
        image_embeddings = self.grounding_encoder.image_encoder(g_pixel_values)

        # Shift the mask left by one so it selects the hidden states that
        # *predict* the [SEG] token.
        seg_token_mask = seg_token_mask[:, 1:]
        seg_token_mask = torch.cat([
            seg_token_mask,
            seg_token_mask.new_zeros(seg_token_mask.shape[0], 1)], dim=-1)

        hidden_states = output.hidden_states
        hidden_states = self.text_hidden_fcs(hidden_states[-1])
        pred_embeddings = hidden_states[seg_token_mask]

        seg_token_counts = seg_token_mask.int().sum(-1)
        pred_embeddings_list = torch.split(
            pred_embeddings, seg_token_counts.tolist(), dim=0)

        pred_masks = self._generate_and_postprocess_masks(
            pred_embeddings_list, image_embeddings, resize_list, ori_size_list)

        if gt_masks is None:
            return {
                'loss_mask': pred_masks[0].sum() * 0.0,
                'loss_dice': pred_masks[0].sum() * 0.0,
                'llm_loss': output.loss,
            }

        bs = len(pred_masks)
        loss_mask, loss_dice = 0, 0
        for i in range(bs):
            pred_mask = pred_masks[i]
            gt_mask = gt_masks[i]
            sam_loss_mask = self.loss_mask(pred_mask, gt_mask)
            sam_loss_dice = self.loss_dice(pred_mask, gt_mask)
            accuracy = torch.eq((pred_mask.sigmoid() > 0.5),
                                gt_mask).to(pred_mask).mean()
            loss_mask += sam_loss_mask
            loss_dice += sam_loss_dice

        loss_dict = {
            'loss_mask': loss_mask / bs,
            'loss_dice': loss_dice / bs,
            'llm_loss': output.loss,
        }
        return loss_dict

    def predict(self, data):
        generation_config = dict(max_new_tokens=1024, do_sample=False)
        eos_token_id = self.tokenizer.convert_tokens_to_ids('<|end|>')
        generation_config['eos_token_id'] = eos_token_id

        pixel_values = data.pop('pixel_values')
        attention_mask = data.pop('attention_mask', None)
        input_ids = data['input_ids']
        generate_output = self.mllm.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict_in_generate=True,
            **generation_config,
        )
        device = self.mllm.model.device

        # Collect the last-layer hidden state of every generated token.
        hidden_states = generate_output.hidden_states
        last_hidden_states = [item[-1] for item in hidden_states[1:]]  # remove input_ids
        last_hidden_states = torch.cat(last_hidden_states, dim=1)
        last_hidden_states = last_hidden_states[0]  # remove batch dim

        output_ids = generate_output.sequences[0][:-1]  # remove batch dim and eos token
        output_text = self.tokenizer.decode(output_ids)
        seg_mask = output_ids == self.seg_token_idx
        if seg_mask.sum() == 0:
            return dict(
                pred_mask_logits=None,
                output_text=output_text,
            )
        seg_embeds = self.text_hidden_fcs(last_hidden_states[seg_mask])
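        # Decode one mask per generated [SEG] token with the grounding
        # encoder (SAM-style image encoder + prompt/mask decoder).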
        g_pixel_values = data.pop('g_pixel_values', None)
        gt_masks = data['masks']
        ori_size_list = [mask.shape[-2:] for mask in gt_masks]
        resize_list = [pixel.shape[-2:] for pixel in g_pixel_values]
        g_pixel_values = torch.stack([
            self.grounding_encoder.preprocess(pixel.to(device))
            for pixel in g_pixel_values
        ])
        image_embeddings = self.grounding_encoder.image_encoder(g_pixel_values)
        pred_masks = self._generate_and_postprocess_masks(
            [seg_embeds], image_embeddings, resize_list, ori_size_list)
        return dict(
            pred_mask_logits=pred_masks[0],  # remove batch dim
            output_text=output_text,
        )

    def gradient_checkpointing_enable(self):
        self.activation_checkpointing_enable()

    def activation_checkpointing_enable(self):
        self.mllm.model.language_model.gradient_checkpointing_enable()

    def gradient_checkpointing_disable(self):
        self.activation_checkpointing_disable()

    def activation_checkpointing_disable(self):
        self.mllm.model.language_model.gradient_checkpointing_disable()
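

# Illustrative sketch, not used by LisaModel itself: converting the raw mask
# logits returned by `predict` into a binary mask. The `sigmoid() > 0.5` rule
# mirrors the accuracy check in `compute_loss`; the helper name and threshold
# argument are hypothetical.
def binarize_mask_logits(mask_logits: torch.Tensor,
                         threshold: float = 0.5) -> torch.Tensor:
    """Turn ``pred_mask_logits`` from ``LisaModel.predict`` into a boolean mask.

    Example (assuming ``outputs = model(data, mode='predict')``)::

        if outputs['pred_mask_logits'] is not None:
            binary_mask = binarize_mask_logits(outputs['pred_mask_logits'])
    """
    return mask_logits.sigmoid() > threshold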