The key names do not match the model state dict keys in PATH

#1
by luomingshuang - opened

Hi, thanks for your great work. I am trying to reproduce your results with the pretrained weights from this URL, but when I load them into the model, I find the keys do not match. The pretrained weight keys are as follows:

odict_keys(['cls_token', 'cls_token_pos_embed', 'pos_embed', 'patch_embed.proj.weight', 'patch_embed.proj.bias', 'blocks.0.norm1.weight', 'blocks.0.norm1.bias', 'blocks.0.attn.qkv.weight', 'blocks.0.attn.qkv.bias', 'blocks.0.attn.proj.weight', 'blocks.0.attn.proj.bias', 'blocks.0.norm2.weight', 'blocks.0.norm2.bias', 'blocks.0.mlp.fc1.weight', 'blocks.0.mlp.fc1.bias', 'blocks.0.mlp.fc2.weight', 'blocks.0.mlp.fc2.bias', 'blocks.1.norm1.weight', 'blocks.1.norm1.bias', 'blocks.1.attn.qkv.weight', 'blocks.1.attn.qkv.bias', 'blocks.1.attn.proj.weight', 'blocks.1.attn.proj.bias', 'blocks.1.norm2.weight', 'blocks.1.norm2.bias', 'blocks.1.mlp.fc1.weight', 'blocks.1.mlp.fc1.bias', 'blocks.1.mlp.fc2.weight', 'blocks.1.mlp.fc2.bias', 'blocks.2.norm1.weight', 'blocks.2.norm1.bias', 'blocks.2.attn.qkv.weight', 'blocks.2.attn.qkv.bias', 'blocks.2.attn.proj.weight', 'blocks.2.attn.proj.bias', 'blocks.2.norm2.weight', 'blocks.2.norm2.bias', 'blocks.2.mlp.fc1.weight', 'blocks.2.mlp.fc1.bias', 'blocks.2.mlp.fc2.weight', 'blocks.2.mlp.fc2.bias', 'blocks.3.norm1.weight', 'blocks.3.norm1.bias', 'blocks.3.attn.qkv.weight', 'blocks.3.attn.qkv.bias', 'blocks.3.attn.proj.weight', 'blocks.3.attn.proj.bias', 'blocks.3.norm2.weight', 'blocks.3.norm2.bias', 'blocks.3.mlp.fc1.weight', 'blocks.3.mlp.fc1.bias', 'blocks.3.mlp.fc2.weight', 'blocks.3.mlp.fc2.bias', 'blocks.4.norm1.weight', 'blocks.4.norm1.bias', 'blocks.4.attn.qkv.weight', 'blocks.4.attn.qkv.bias', 'blocks.4.attn.proj.weight', 'blocks.4.attn.proj.bias', 'blocks.4.norm2.weight', 'blocks.4.norm2.bias', 'blocks.4.mlp.fc1.weight', 'blocks.4.mlp.fc1.bias', 'blocks.4.mlp.fc2.weight', 'blocks.4.mlp.fc2.bias', 'blocks.5.norm1.weight', 'blocks.5.norm1.bias', 'blocks.5.attn.qkv.weight', 'blocks.5.attn.qkv.bias', 'blocks.5.attn.proj.weight', 'blocks.5.attn.proj.bias', 'blocks.5.norm2.weight', 'blocks.5.norm2.bias', 'blocks.5.mlp.fc1.weight', 'blocks.5.mlp.fc1.bias', 'blocks.5.mlp.fc2.weight', 'blocks.5.mlp.fc2.bias', 'blocks.6.norm1.weight', 'blocks.6.norm1.bias', 'blocks.6.attn.qkv.weight', 'blocks.6.attn.qkv.bias', 'blocks.6.attn.proj.weight', 'blocks.6.attn.proj.bias', 'blocks.6.norm2.weight', 'blocks.6.norm2.bias', 'blocks.6.mlp.fc1.weight', 'blocks.6.mlp.fc1.bias', 'blocks.6.mlp.fc2.weight', 'blocks.6.mlp.fc2.bias', 'blocks.7.norm1.weight', 'blocks.7.norm1.bias', 'blocks.7.attn.qkv.weight', 'blocks.7.attn.qkv.bias', 'blocks.7.attn.proj.weight', 'blocks.7.attn.proj.bias', 'blocks.7.norm2.weight', 'blocks.7.norm2.bias', 'blocks.7.mlp.fc1.weight', 'blocks.7.mlp.fc1.bias', 'blocks.7.mlp.fc2.weight', 'blocks.7.mlp.fc2.bias', 'blocks.8.norm1.weight', 'blocks.8.norm1.bias', 'blocks.8.attn.qkv.weight', 'blocks.8.attn.qkv.bias', 'blocks.8.attn.proj.weight', 'blocks.8.attn.proj.bias', 'blocks.8.norm2.weight', 'blocks.8.norm2.bias', 'blocks.8.mlp.fc1.weight', 'blocks.8.mlp.fc1.bias', 'blocks.8.mlp.fc2.weight', 'blocks.8.mlp.fc2.bias', 'blocks.9.norm1.weight', 'blocks.9.norm1.bias', 'blocks.9.attn.qkv.weight', 'blocks.9.attn.qkv.bias', 'blocks.9.attn.proj.weight', 'blocks.9.attn.proj.bias', 'blocks.9.norm2.weight', 'blocks.9.norm2.bias', 'blocks.9.mlp.fc1.weight', 'blocks.9.mlp.fc1.bias', 'blocks.9.mlp.fc2.weight', 'blocks.9.mlp.fc2.bias', 'blocks.10.norm1.weight', 'blocks.10.norm1.bias', 'blocks.10.attn.qkv.weight', 'blocks.10.attn.qkv.bias', 'blocks.10.attn.proj.weight', 'blocks.10.attn.proj.bias', 'blocks.10.norm2.weight', 'blocks.10.norm2.bias', 'blocks.10.mlp.fc1.weight', 'blocks.10.mlp.fc1.bias', 'blocks.10.mlp.fc2.weight', 'blocks.10.mlp.fc2.bias', 
'blocks.11.norm1.weight', 'blocks.11.norm1.bias', 'blocks.11.attn.qkv.weight', 'blocks.11.attn.qkv.bias', 'blocks.11.attn.proj.weight', 'blocks.11.attn.proj.bias', 'blocks.11.norm2.weight', 'blocks.11.norm2.bias', 'blocks.11.mlp.fc1.weight', 'blocks.11.mlp.fc1.bias', 'blocks.11.mlp.fc2.weight', 'blocks.11.mlp.fc2.bias', 'blocks.12.norm1.weight', 'blocks.12.norm1.bias', 'blocks.12.attn.qkv.weight', 'blocks.12.attn.qkv.bias', 'blocks.12.attn.proj.weight', 'blocks.12.attn.proj.bias', 'blocks.12.norm2.weight', 'blocks.12.norm2.bias', 'blocks.12.mlp.fc1.weight', 'blocks.12.mlp.fc1.bias', 'blocks.12.mlp.fc2.weight', 'blocks.12.mlp.fc2.bias', 'blocks.13.norm1.weight', 'blocks.13.norm1.bias', 'blocks.13.attn.qkv.weight', 'blocks.13.attn.qkv.bias', 'blocks.13.attn.proj.weight', 'blocks.13.attn.proj.bias', 'blocks.13.norm2.weight', 'blocks.13.norm2.bias', 'blocks.13.mlp.fc1.weight', 'blocks.13.mlp.fc1.bias', 'blocks.13.mlp.fc2.weight', 'blocks.13.mlp.fc2.bias', 'blocks.14.norm1.weight', 'blocks.14.norm1.bias', 'blocks.14.attn.qkv.weight', 'blocks.14.attn.qkv.bias', 'blocks.14.attn.proj.weight', 'blocks.14.attn.proj.bias', 'blocks.14.norm2.weight', 'blocks.14.norm2.bias', 'blocks.14.mlp.fc1.weight', 'blocks.14.mlp.fc1.bias', 'blocks.14.mlp.fc2.weight', 'blocks.14.mlp.fc2.bias', 'blocks.15.norm1.weight', 'blocks.15.norm1.bias', 'blocks.15.attn.qkv.weight', 'blocks.15.attn.qkv.bias', 'blocks.15.attn.proj.weight', 'blocks.15.attn.proj.bias', 'blocks.15.norm2.weight', 'blocks.15.norm2.bias', 'blocks.15.mlp.fc1.weight', 'blocks.15.mlp.fc1.bias', 'blocks.15.mlp.fc2.weight', 'blocks.15.mlp.fc2.bias', 'blocks.16.norm1.weight', 'blocks.16.norm1.bias', 'blocks.16.attn.qkv.weight', 'blocks.16.attn.qkv.bias', 'blocks.16.attn.proj.weight', 'blocks.16.attn.proj.bias', 'blocks.16.norm2.weight', 'blocks.16.norm2.bias', 'blocks.16.mlp.fc1.weight', 'blocks.16.mlp.fc1.bias', 'blocks.16.mlp.fc2.weight', 'blocks.16.mlp.fc2.bias', 'blocks.17.norm1.weight', 'blocks.17.norm1.bias', 'blocks.17.attn.qkv.weight', 'blocks.17.attn.qkv.bias', 'blocks.17.attn.proj.weight', 'blocks.17.attn.proj.bias', 'blocks.17.norm2.weight', 'blocks.17.norm2.bias', 'blocks.17.mlp.fc1.weight', 'blocks.17.mlp.fc1.bias', 'blocks.17.mlp.fc2.weight', 'blocks.17.mlp.fc2.bias', 'blocks.18.norm1.weight', 'blocks.18.norm1.bias', 'blocks.18.attn.qkv.weight', 'blocks.18.attn.qkv.bias', 'blocks.18.attn.proj.weight', 'blocks.18.attn.proj.bias', 'blocks.18.norm2.weight', 'blocks.18.norm2.bias', 'blocks.18.mlp.fc1.weight', 'blocks.18.mlp.fc1.bias', 'blocks.18.mlp.fc2.weight', 'blocks.18.mlp.fc2.bias', 'blocks.19.norm1.weight', 'blocks.19.norm1.bias', 'blocks.19.attn.qkv.weight', 'blocks.19.attn.qkv.bias', 'blocks.19.attn.proj.weight', 'blocks.19.attn.proj.bias', 'blocks.19.norm2.weight', 'blocks.19.norm2.bias', 'blocks.19.mlp.fc1.weight', 'blocks.19.mlp.fc1.bias', 'blocks.19.mlp.fc2.weight', 'blocks.19.mlp.fc2.bias', 'blocks.20.norm1.weight', 'blocks.20.norm1.bias', 'blocks.20.attn.qkv.weight', 'blocks.20.attn.qkv.bias', 'blocks.20.attn.proj.weight', 'blocks.20.attn.proj.bias', 'blocks.20.norm2.weight', 'blocks.20.norm2.bias', 'blocks.20.mlp.fc1.weight', 'blocks.20.mlp.fc1.bias', 'blocks.20.mlp.fc2.weight', 'blocks.20.mlp.fc2.bias', 'blocks.21.norm1.weight', 'blocks.21.norm1.bias', 'blocks.21.attn.qkv.weight', 'blocks.21.attn.qkv.bias', 'blocks.21.attn.proj.weight', 'blocks.21.attn.proj.bias', 'blocks.21.norm2.weight', 'blocks.21.norm2.bias', 'blocks.21.mlp.fc1.weight', 'blocks.21.mlp.fc1.bias', 'blocks.21.mlp.fc2.weight', 'blocks.21.mlp.fc2.bias', 
'blocks.22.norm1.weight', 'blocks.22.norm1.bias', 'blocks.22.attn.qkv.weight', 'blocks.22.attn.qkv.bias', 'blocks.22.attn.proj.weight', 'blocks.22.attn.proj.bias', 'blocks.22.norm2.weight', 'blocks.22.norm2.bias', 'blocks.22.mlp.fc1.weight', 'blocks.22.mlp.fc1.bias', 'blocks.22.mlp.fc2.weight', 'blocks.22.mlp.fc2.bias', 'blocks.23.norm1.weight', 'blocks.23.norm1.bias', 'blocks.23.attn.qkv.weight', 'blocks.23.attn.qkv.bias', 'blocks.23.attn.proj.weight', 'blocks.23.attn.proj.bias', 'blocks.23.norm2.weight', 'blocks.23.norm2.bias', 'blocks.23.mlp.fc1.weight', 'blocks.23.mlp.fc1.bias', 'blocks.23.mlp.fc2.weight', 'blocks.23.mlp.fc2.bias', 'norm.weight', 'norm.bias'])

But the keys for the model in PATH are as follows:

odict_keys(['module.backbone_module.pos_embed', 'module.backbone_module.patch_embed.proj.weight', 'module.backbone_module.patch_embed.proj.bias', 'module.backbone_module.blocks.0.norm1.weight', 'module.backbone_module.blocks.0.norm1.bias', 'module.backbone_module.blocks.0.attn.qkv.weight', 'module.backbone_module.blocks.0.attn.qkv.bias', 'module.backbone_module.blocks.0.attn.proj.weight', 'module.backbone_module.blocks.0.attn.proj.bias', 'module.backbone_module.blocks.0.norm2.weight', 'module.backbone_module.blocks.0.norm2.bias', 'module.backbone_module.blocks.0.mlp.fc1.weight', 'module.backbone_module.blocks.0.mlp.fc1.bias', 'module.backbone_module.blocks.0.mlp.fc2.weight', 'module.backbone_module.blocks.0.mlp.fc2.bias', 'module.backbone_module.blocks.1.norm1.weight', 'module.backbone_module.blocks.1.norm1.bias', 'module.backbone_module.blocks.1.attn.qkv.weight', 'module.backbone_module.blocks.1.attn.qkv.bias', 'module.backbone_module.blocks.1.attn.proj.weight', 'module.backbone_module.blocks.1.attn.proj.bias', 'module.backbone_module.blocks.1.norm2.weight', 'module.backbone_module.blocks.1.norm2.bias', 'module.backbone_module.blocks.1.mlp.fc1.weight', 'module.backbone_module.blocks.1.mlp.fc1.bias', 'module.backbone_module.blocks.1.mlp.fc2.weight', 'module.backbone_module.blocks.1.mlp.fc2.bias', 'module.backbone_module.blocks.2.norm1.weight', 'module.backbone_module.blocks.2.norm1.bias', 'module.backbone_module.blocks.2.attn.qkv.weight', 'module.backbone_module.blocks.2.attn.qkv.bias', 'module.backbone_module.blocks.2.attn.proj.weight', 'module.backbone_module.blocks.2.attn.proj.bias', 'module.backbone_module.blocks.2.norm2.weight', 'module.backbone_module.blocks.2.norm2.bias', 'module.backbone_module.blocks.2.mlp.fc1.weight', 'module.backbone_module.blocks.2.mlp.fc1.bias', 'module.backbone_module.blocks.2.mlp.fc2.weight', 'module.backbone_module.blocks.2.mlp.fc2.bias', 'module.backbone_module.blocks.3.norm1.weight', 'module.backbone_module.blocks.3.norm1.bias', 'module.backbone_module.blocks.3.attn.qkv.weight', 'module.backbone_module.blocks.3.attn.qkv.bias', 'module.backbone_module.blocks.3.attn.proj.weight', 'module.backbone_module.blocks.3.attn.proj.bias', 'module.backbone_module.blocks.3.norm2.weight', 'module.backbone_module.blocks.3.norm2.bias', 'module.backbone_module.blocks.3.mlp.fc1.weight', 'module.backbone_module.blocks.3.mlp.fc1.bias', 'module.backbone_module.blocks.3.mlp.fc2.weight', 'module.backbone_module.blocks.3.mlp.fc2.bias', 'module.backbone_module.blocks.4.norm1.weight', 'module.backbone_module.blocks.4.norm1.bias', 'module.backbone_module.blocks.4.attn.qkv.weight', 'module.backbone_module.blocks.4.attn.qkv.bias', 'module.backbone_module.blocks.4.attn.proj.weight', 'module.backbone_module.blocks.4.attn.proj.bias', 'module.backbone_module.blocks.4.norm2.weight', 'module.backbone_module.blocks.4.norm2.bias', 'module.backbone_module.blocks.4.mlp.fc1.weight', 'module.backbone_module.blocks.4.mlp.fc1.bias', 'module.backbone_module.blocks.4.mlp.fc2.weight', 'module.backbone_module.blocks.4.mlp.fc2.bias', 'module.backbone_module.blocks.5.norm1.weight', 'module.backbone_module.blocks.5.norm1.bias', 'module.backbone_module.blocks.5.attn.qkv.weight', 'module.backbone_module.blocks.5.attn.qkv.bias', 'module.backbone_module.blocks.5.attn.proj.weight', 'module.backbone_module.blocks.5.attn.proj.bias', 'module.backbone_module.blocks.5.norm2.weight', 'module.backbone_module.blocks.5.norm2.bias', 'module.backbone_module.blocks.5.mlp.fc1.weight', 
'module.backbone_module.blocks.5.mlp.fc1.bias', 'module.backbone_module.blocks.5.mlp.fc2.weight', 'module.backbone_module.blocks.5.mlp.fc2.bias', 'module.backbone_module.blocks.6.norm1.weight', 'module.backbone_module.blocks.6.norm1.bias', 'module.backbone_module.blocks.6.attn.qkv.weight', 'module.backbone_module.blocks.6.attn.qkv.bias', 'module.backbone_module.blocks.6.attn.proj.weight', 'module.backbone_module.blocks.6.attn.proj.bias', 'module.backbone_module.blocks.6.norm2.weight', 'module.backbone_module.blocks.6.norm2.bias', 'module.backbone_module.blocks.6.mlp.fc1.weight', 'module.backbone_module.blocks.6.mlp.fc1.bias', 'module.backbone_module.blocks.6.mlp.fc2.weight', 'module.backbone_module.blocks.6.mlp.fc2.bias', 'module.backbone_module.blocks.7.norm1.weight', 'module.backbone_module.blocks.7.norm1.bias', 'module.backbone_module.blocks.7.attn.qkv.weight', 'module.backbone_module.blocks.7.attn.qkv.bias', 'module.backbone_module.blocks.7.attn.proj.weight', 'module.backbone_module.blocks.7.attn.proj.bias', 'module.backbone_module.blocks.7.norm2.weight', 'module.backbone_module.blocks.7.norm2.bias', 'module.backbone_module.blocks.7.mlp.fc1.weight', 'module.backbone_module.blocks.7.mlp.fc1.bias', 'module.backbone_module.blocks.7.mlp.fc2.weight', 'module.backbone_module.blocks.7.mlp.fc2.bias', 'module.backbone_module.blocks.8.norm1.weight', 'module.backbone_module.blocks.8.norm1.bias', 'module.backbone_module.blocks.8.attn.qkv.weight', 'module.backbone_module.blocks.8.attn.qkv.bias', 'module.backbone_module.blocks.8.attn.proj.weight', 'module.backbone_module.blocks.8.attn.proj.bias', 'module.backbone_module.blocks.8.norm2.weight', 'module.backbone_module.blocks.8.norm2.bias', 'module.backbone_module.blocks.8.mlp.fc1.weight', 'module.backbone_module.blocks.8.mlp.fc1.bias', 'module.backbone_module.blocks.8.mlp.fc2.weight', 'module.backbone_module.blocks.8.mlp.fc2.bias', 'module.backbone_module.blocks.9.norm1.weight', 'module.backbone_module.blocks.9.norm1.bias', 'module.backbone_module.blocks.9.attn.qkv.weight', 'module.backbone_module.blocks.9.attn.qkv.bias', 'module.backbone_module.blocks.9.attn.proj.weight', 'module.backbone_module.blocks.9.attn.proj.bias', 'module.backbone_module.blocks.9.norm2.weight', 'module.backbone_module.blocks.9.norm2.bias', 'module.backbone_module.blocks.9.mlp.fc1.weight', 'module.backbone_module.blocks.9.mlp.fc1.bias', 'module.backbone_module.blocks.9.mlp.fc2.weight', 'module.backbone_module.blocks.9.mlp.fc2.bias', 'module.backbone_module.blocks.10.norm1.weight', 'module.backbone_module.blocks.10.norm1.bias', 'module.backbone_module.blocks.10.attn.qkv.weight', 'module.backbone_module.blocks.10.attn.qkv.bias', 'module.backbone_module.blocks.10.attn.proj.weight', 'module.backbone_module.blocks.10.attn.proj.bias', 'module.backbone_module.blocks.10.norm2.weight', 'module.backbone_module.blocks.10.norm2.bias', 'module.backbone_module.blocks.10.mlp.fc1.weight', 'module.backbone_module.blocks.10.mlp.fc1.bias', 'module.backbone_module.blocks.10.mlp.fc2.weight', 'module.backbone_module.blocks.10.mlp.fc2.bias', 'module.backbone_module.blocks.11.norm1.weight', 'module.backbone_module.blocks.11.norm1.bias', 'module.backbone_module.blocks.11.attn.qkv.weight', 'module.backbone_module.blocks.11.attn.qkv.bias', 'module.backbone_module.blocks.11.attn.proj.weight', 'module.backbone_module.blocks.11.attn.proj.bias', 'module.backbone_module.blocks.11.norm2.weight', 'module.backbone_module.blocks.11.norm2.bias', 'module.backbone_module.blocks.11.mlp.fc1.weight', 
'module.backbone_module.blocks.11.mlp.fc1.bias', 'module.backbone_module.blocks.11.mlp.fc2.weight', 'module.backbone_module.blocks.11.mlp.fc2.bias', 'module.backbone_module.norm.weight', 'module.backbone_module.norm.bias', 'module.neck_module.reduction_layers.0.weight', 'module.neck_module.reduction_layers.0.bias', 'module.neck_module.reduction_layers.1.weight', 'module.neck_module.reduction_layers.1.bias', 'module.neck_module.reduction_layers.2.weight', 'module.neck_module.reduction_layers.2.bias', 'module.neck_module.reduction_layers.3.weight', 'module.neck_module.reduction_layers.3.bias', 'module.neck_module.reduction_layers.4.weight', 'module.neck_module.reduction_layers.4.bias', 'module.neck_module.reduction_layers.5.weight', 'module.neck_module.reduction_layers.5.bias', 'module.neck_module.reduction_layers.6.weight', 'module.neck_module.reduction_layers.6.bias', 'module.neck_module.reduction_layers.7.weight', 'module.neck_module.reduction_layers.7.bias', 'module.neck_module.reduction_layers.8.weight', 'module.neck_module.reduction_layers.8.bias', 'module.neck_module.reduction_layers.9.weight', 'module.neck_module.reduction_layers.9.bias', 'module.neck_module.reduction_layers.10.weight', 'module.neck_module.reduction_layers.10.bias', 'module.neck_module.reduction_layers.11.weight', 'module.neck_module.reduction_layers.11.bias', 'module.neck_module.reduction_layers.12.weight', 'module.neck_module.reduction_layers.12.bias', 'module.neck_module.side_gate_params.0', 'module.neck_module.side_gate_params.1', 'module.neck_module.side_gate_params.2', 'module.neck_module.side_gate_params.3', 'module.neck_module.side_gate_params.4', 'module.neck_module.side_gate_params.5', 'module.neck_module.side_gate_params.6', 'module.neck_module.side_gate_params.7', 'module.neck_module.side_gate_params.8', 'module.neck_module.side_gate_params.9', 'module.neck_module.side_gate_params.10', 'module.neck_module.side_gate_params.11', 'module.neck_module.transformer_blocks.0.0.norm1.weight', 'module.neck_module.transformer_blocks.0.0.norm1.bias', 'module.neck_module.transformer_blocks.0.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.0.0.attn.proj.weight', 'module.neck_module.transformer_blocks.0.0.attn.proj.bias', 'module.neck_module.transformer_blocks.0.0.norm2.weight', 'module.neck_module.transformer_blocks.0.0.norm2.bias', 'module.neck_module.transformer_blocks.0.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.0.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.0.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.0.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.1.0.norm1.weight', 'module.neck_module.transformer_blocks.1.0.norm1.bias', 'module.neck_module.transformer_blocks.1.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.1.0.attn.proj.weight', 'module.neck_module.transformer_blocks.1.0.attn.proj.bias', 'module.neck_module.transformer_blocks.1.0.norm2.weight', 'module.neck_module.transformer_blocks.1.0.norm2.bias', 'module.neck_module.transformer_blocks.1.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.1.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.1.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.1.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.2.0.norm1.weight', 'module.neck_module.transformer_blocks.2.0.norm1.bias', 'module.neck_module.transformer_blocks.2.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.2.0.attn.proj.weight', 'module.neck_module.transformer_blocks.2.0.attn.proj.bias', 
'module.neck_module.transformer_blocks.2.0.norm2.weight', 'module.neck_module.transformer_blocks.2.0.norm2.bias', 'module.neck_module.transformer_blocks.2.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.2.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.2.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.2.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.3.0.norm1.weight', 'module.neck_module.transformer_blocks.3.0.norm1.bias', 'module.neck_module.transformer_blocks.3.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.3.0.attn.proj.weight', 'module.neck_module.transformer_blocks.3.0.attn.proj.bias', 'module.neck_module.transformer_blocks.3.0.norm2.weight', 'module.neck_module.transformer_blocks.3.0.norm2.bias', 'module.neck_module.transformer_blocks.3.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.3.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.3.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.3.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.4.0.norm1.weight', 'module.neck_module.transformer_blocks.4.0.norm1.bias', 'module.neck_module.transformer_blocks.4.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.4.0.attn.proj.weight', 'module.neck_module.transformer_blocks.4.0.attn.proj.bias', 'module.neck_module.transformer_blocks.4.0.norm2.weight', 'module.neck_module.transformer_blocks.4.0.norm2.bias', 'module.neck_module.transformer_blocks.4.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.4.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.4.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.4.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.5.0.norm1.weight', 'module.neck_module.transformer_blocks.5.0.norm1.bias', 'module.neck_module.transformer_blocks.5.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.5.0.attn.proj.weight', 'module.neck_module.transformer_blocks.5.0.attn.proj.bias', 'module.neck_module.transformer_blocks.5.0.norm2.weight', 'module.neck_module.transformer_blocks.5.0.norm2.bias', 'module.neck_module.transformer_blocks.5.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.5.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.5.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.5.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.6.0.norm1.weight', 'module.neck_module.transformer_blocks.6.0.norm1.bias', 'module.neck_module.transformer_blocks.6.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.6.0.attn.proj.weight', 'module.neck_module.transformer_blocks.6.0.attn.proj.bias', 'module.neck_module.transformer_blocks.6.0.norm2.weight', 'module.neck_module.transformer_blocks.6.0.norm2.bias', 'module.neck_module.transformer_blocks.6.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.6.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.6.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.6.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.7.0.norm1.weight', 'module.neck_module.transformer_blocks.7.0.norm1.bias', 'module.neck_module.transformer_blocks.7.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.7.0.attn.proj.weight', 'module.neck_module.transformer_blocks.7.0.attn.proj.bias', 'module.neck_module.transformer_blocks.7.0.norm2.weight', 'module.neck_module.transformer_blocks.7.0.norm2.bias', 'module.neck_module.transformer_blocks.7.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.7.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.7.0.mlp.fc2.weight', 
'module.neck_module.transformer_blocks.7.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.8.0.norm1.weight', 'module.neck_module.transformer_blocks.8.0.norm1.bias', 'module.neck_module.transformer_blocks.8.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.8.0.attn.proj.weight', 'module.neck_module.transformer_blocks.8.0.attn.proj.bias', 'module.neck_module.transformer_blocks.8.0.norm2.weight', 'module.neck_module.transformer_blocks.8.0.norm2.bias', 'module.neck_module.transformer_blocks.8.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.8.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.8.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.8.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.9.0.norm1.weight', 'module.neck_module.transformer_blocks.9.0.norm1.bias', 'module.neck_module.transformer_blocks.9.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.9.0.attn.proj.weight', 'module.neck_module.transformer_blocks.9.0.attn.proj.bias', 'module.neck_module.transformer_blocks.9.0.norm2.weight', 'module.neck_module.transformer_blocks.9.0.norm2.bias', 'module.neck_module.transformer_blocks.9.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.9.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.9.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.9.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.10.0.norm1.weight', 'module.neck_module.transformer_blocks.10.0.norm1.bias', 'module.neck_module.transformer_blocks.10.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.10.0.attn.proj.weight', 'module.neck_module.transformer_blocks.10.0.attn.proj.bias', 'module.neck_module.transformer_blocks.10.0.norm2.weight', 'module.neck_module.transformer_blocks.10.0.norm2.bias', 'module.neck_module.transformer_blocks.10.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.10.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.10.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.10.0.mlp.fc2.bias', 'module.neck_module.transformer_blocks.11.0.norm1.weight', 'module.neck_module.transformer_blocks.11.0.norm1.bias', 'module.neck_module.transformer_blocks.11.0.attn.qkv.weight', 'module.neck_module.transformer_blocks.11.0.attn.proj.weight', 'module.neck_module.transformer_blocks.11.0.attn.proj.bias', 'module.neck_module.transformer_blocks.11.0.norm2.weight', 'module.neck_module.transformer_blocks.11.0.norm2.bias', 'module.neck_module.transformer_blocks.11.0.mlp.fc1.weight', 'module.neck_module.transformer_blocks.11.0.mlp.fc1.bias', 'module.neck_module.transformer_blocks.11.0.mlp.fc2.weight', 'module.neck_module.transformer_blocks.11.0.mlp.fc2.bias', 'module.neck_module.last_proj.weight', 'module.neck_module.last_proj.bias', 'module.decoder_module.logits.0.weight', 'module.decoder_module.logits.0.bias', 'module.decoder_module.logits.1.weight', 'module.decoder_module.logits.1.bias', 'module.decoder_module.logits.1.running_mean', 'module.decoder_module.logits.1.running_var', 'module.decoder_module.logits.1.num_batches_tracked'])

So the pretrained weight keys cannot be mapped onto the model directly. Could you give me some suggestions on how to load the pretrained weights for the PATH model?

Thanks a lot!!

OpenGVLab org

Thanks for your interest in our work!
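
The `module.backbone_module.` prefix in your second key list comes from how the multi-task training checkpoints are saved: the ViT backbone is a `backbone_module` submodule of a wrapped model (hence the leading `module.`). If you only need the plain backbone weights, a minimal generic sketch is to strip that prefix and load non-strictly. The checkpoint path and the timm model name here are placeholders, not part of the PATH release; the proper conversion is the script below.

import torch
import timm

# Hypothetical path: point this at a released ckpt_task*_iter_newest.pth.tar file.
ckpt = torch.load('path/to/ckpt_task0_iter_newest.pth.tar', map_location='cpu')

# Keep only backbone entries and drop the 'module.backbone_module.' wrapper prefix.
prefix = 'module.backbone_module.'
backbone_sd = {k[len(prefix):]: v for k, v in ckpt['state_dict'].items()
               if k.startswith(prefix)}

vit = timm.create_model('vit_base_patch16_224', pretrained=False)

# Drop entries whose shapes differ (e.g. 'pos_embed': PATH keeps the cls-token
# position embedding separately, so its 'pos_embed' has one token fewer than timm's).
model_sd = vit.state_dict()
backbone_sd = {k: v for k, v in backbone_sd.items()
               if k in model_sd and v.shape == model_sd[k].shape}

# strict=False tolerates the remaining missing keys ('cls_token', 'pos_embed', head).
missing, unexpected = vit.load_state_dict(backbone_sd, strict=False)
print('missing:', missing)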

Please use the transferred model (with_cls_token-reid) for ReID tasks and the transferred model (wo_cls_token-reid) for other tasks! The script below shows how these transferred checkpoints are produced from a trained PATH checkpoint:

import os
import sys
import collections
import torch
import numpy as np

sys.path.append('/mnt/cache/chencheng1/vitruvian/vitruvian-multitask')

mae_pretrain_path = '/mnt/cache/chencheng1/vitruvian/vitruvian-multitask/core/models/backbones/pretrain_weights/mae_pretrain_vit_base.pth'  # path to the original MAE ViT-Base weights
mae_model = torch.load(mae_pretrain_path)

save_root = '/mnt/lustre/share_data/chencheng1/vitruvian/L2_final_base'  # folder where the transferred checkpoints are saved

root = '/mnt/lustre/chencheng1/expr_files/vitruvian/L2_full_setting_joint/checkpoints'  # folder containing the trained PATH checkpoints
config_lists = [
    'v100_32g_vitbase_size224_lr1e3_stepLRx3_bmp1_adafactor_wd01_clip05_layerdecay075_lpe_peddet_citypersons_LSA_reduct8_tbn1_heads2_gate1_peddetShareDecoder_exp3_setting_SharePosEmbed'
]

for config in config_lists:
    trained_ckpt_root = os.path.join(root, config)  # where this run's trained checkpoints are saved
    expr_name = trained_ckpt_root.split('/')[-1]

    wo_cls_token_index = 0     # task trained without a cls token (coco)
    with_cls_token_index = 20  # task trained with a cls token (reid_4set)

    # with_cls_token-reid
    with_cls_token_train_model_path = os.path.join(trained_ckpt_root, 'ckpt_task{}_iter_newest.pth.tar'.format(with_cls_token_index))
    with_cls_token_transed_ckpt_save_path = os.path.join(save_root, 'with_cls_token', expr_name + '.pth')
    with_cls_token_train_model = torch.load(with_cls_token_train_model_path, map_location=torch.device('cpu'))

    cnt = 0
    traned_ckpt = collections.OrderedDict()
    for name, param in mae_model['model'].items():
        trained_model_name = 'module.backbone_module.' + name
        if trained_model_name in with_cls_token_train_model['state_dict']:
            if name == 'pos_embed':
                cnt += 1
                # PATH stores the cls-token position embedding separately as
                # 'cls_token_pos_embed'; concatenate it back in front of the patch
                # position embeddings to recover the MAE-style 'pos_embed'.
                traned_ckpt[name] = torch.cat([with_cls_token_train_model['state_dict']['module.backbone_module.cls_token_pos_embed'], with_cls_token_train_model['state_dict'][trained_model_name]], dim=1)
            else:
                cnt += 1
                traned_ckpt[name] = with_cls_token_train_model['state_dict'][trained_model_name]
        else:
            # Keys absent from the trained checkpoint fall back to the MAE weights.
            traned_ckpt[name] = mae_model['model'][name]

    torch.save({'model': traned_ckpt}, with_cls_token_transed_ckpt_save_path)
    print('done! transed ckpt saved at: {}'.format(with_cls_token_transed_ckpt_save_path))

    # wo_cls_token-reid
    wo_cls_token_train_model_path = os.path.join(trained_ckpt_root, 'ckpt_task{}_iter_newest.pth.tar'.format(wo_cls_token_index))
    wo_cls_token_transed_ckpt_save_path = os.path.join(save_root, 'wo_cls_token', expr_name + '.pth')
    wo_cls_token_train_model = torch.load(wo_cls_token_train_model_path, map_location=torch.device('cpu'))

    cnt = 0
    traned_ckpt = collections.OrderedDict()
    for name, param in mae_model['model'].items():
        trained_model_name = 'module.backbone_module.' + name
        if trained_model_name in wo_cls_token_train_model['state_dict']:
            cnt += 1
            traned_ckpt[name] = wo_cls_token_train_model['state_dict'][trained_model_name]
        else:
            traned_ckpt[name] = mae_model['model'][name]

    torch.save({'model': traned_ckpt}, wo_cls_token_transed_ckpt_save_path)
    print('done! transed ckpt saved at: {}'.format(wo_cls_token_transed_ckpt_save_path))
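
Because both loops iterate over `mae_model['model']` and fall back to the MAE weights for any key the trained checkpoint lacks, each transferred checkpoint ends up with exactly the same key set as the original MAE file. As a sketch (not part of the original script), a quick sanity check can be appended at the end of the loop body:

    # Optional sanity check: the transferred checkpoint exposes exactly the
    # MAE-style keys, so it loads wherever the original MAE weights load.
    transferred = torch.load(wo_cls_token_transed_ckpt_save_path, map_location='cpu')
    assert list(transferred['model'].keys()) == list(mae_model['model'].keys())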
