|
|
|
|
|
|
|
import pickle as pkl |
|
import sys |
|
import torch |
|
|
|
""" |
|
Usage:
  # download one of the ResNet{18,34,50,101,152} models from torchvision:
  wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth
  # run the conversion
  ./convert-torchvision-to-d2.py r50.pth r50.pkl

  # Then, use r50.pkl with the following changes in config:

MODEL:
  WEIGHTS: "/path/to/r50.pkl"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  RESNETS:
    DEPTH: 50
    STRIDE_IN_1X1: False
INPUT:
  FORMAT: "RGB"

  These models typically produce slightly worse results than the
  pre-trained ResNets we use in official configs, which are the
  original ResNet models released by MSRA.
|
""" |
|
|
|
def rename_key(name):
    """Map one torchvision ResNet state-dict key to its converted name.

    The rewrite rules are purely textual and applied in this order:

    1. Keys that do not contain ``"layer"`` are prefixed with ``"stem."``.
    2. ``layer1`` .. ``layer4`` become ``res2`` .. ``res5``.
    3. ``bn1`` .. ``bn3`` become ``conv1.norm`` .. ``conv3.norm``.
    4. ``downsample.0`` becomes ``shortcut`` and ``downsample.1`` becomes
       ``shortcut.norm``.

    Args:
        name: The original key from the loaded checkpoint.

    Returns:
        The renamed key as a new string.
    """
    if "layer" not in name:
        name = "stem." + name
    for idx in (1, 2, 3, 4):
        name = name.replace("layer{}".format(idx), "res{}".format(idx + 1))
    for idx in (1, 2, 3):
        name = name.replace("bn{}".format(idx), "conv{}.norm".format(idx))
    name = name.replace("downsample.0", "shortcut")
    name = name.replace("downsample.1", "shortcut.norm")
    return name


if __name__ == "__main__":
    # Fail early with a readable message instead of an IndexError that would
    # otherwise surface only after the checkpoint has been loaded/converted.
    if len(sys.argv) < 3:
        sys.exit("Usage: {} <input.pth> <output.pkl>".format(sys.argv[0]))
    input_path, output_path = sys.argv[1], sys.argv[2]

    # NOTE(security): torch.load unpickles the file; only run this script on
    # checkpoints obtained from a trusted source.
    obj = torch.load(input_path, map_location="cpu")

    newmodel = {}
    # Iterate over a snapshot of the keys because entries are popped from
    # `obj` while looping.
    for old_k in list(obj.keys()):
        k = rename_key(old_k)
        print(old_k, "->", k)
        newmodel[k] = obj.pop(old_k).detach().numpy()

    res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True}

    with open(output_path, "wb") as f:
        pkl.dump(res, f)
    # Every key is popped in the loop above, so anything left here means a
    # key was not converted; report it rather than dropping it silently.
    if obj:
        print("Unconverted keys:", obj.keys())
|
|