Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
import pickle as pkl | |
import sys | |
import torch | |
""" | |
Usage: | |
# download one of the ResNet{18,34,50,101,152} models from torchvision: | |
wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O r50.pth | |
# run the conversion | |
./convert-torchvision-to-d2.py r50.pth r50.pkl | |
# Then, use r50.pkl with the following changes in config: | |
MODEL: | |
WEIGHTS: "/path/to/r50.pkl" | |
PIXEL_MEAN: [123.675, 116.280, 103.530] | |
PIXEL_STD: [58.395, 57.120, 57.375] | |
RESNETS: | |
DEPTH: 50 | |
STRIDE_IN_1X1: False | |
INPUT: | |
FORMAT: "RGB" | |
""" | |
def convert_key(k):
    """Return the converted name for one torchvision ResNet state-dict key.

    The rewrite rules are applied in this order:

    * keys that do not contain ``"layer"`` get a ``"stem."`` prefix;
    * ``layer1`` .. ``layer4`` become ``res2`` .. ``res5``;
    * ``bn1`` .. ``bn3`` become ``conv1.norm`` .. ``conv3.norm``;
    * ``downsample.0`` becomes ``shortcut`` and ``downsample.1`` becomes
      ``shortcut.norm``.

    Args:
        k: Key as it appears in the loaded torchvision checkpoint.

    Returns:
        The renamed key (a new string; the input is not modified).
    """
    # Everything outside the four residual stages is placed under "stem.".
    if "layer" not in k:
        k = "stem." + k
    # Stage numbering is shifted by one: layer1 -> res2, ..., layer4 -> res5.
    for t in [1, 2, 3, 4]:
        k = k.replace("layer{}".format(t), "res{}".format(t + 1))
    # Each batch-norm is renamed to the ".norm" child of the same-numbered conv.
    for t in [1, 2, 3]:
        k = k.replace("bn{}".format(t), "conv{}.norm".format(t))
    k = k.replace("downsample.0", "shortcut")
    k = k.replace("downsample.1", "shortcut.norm")
    return k


def main(argv=None):
    """Convert a torchvision ResNet checkpoint into a pickled weights file.

    Loads the checkpoint at ``argv[1]``, renames every key with
    :func:`convert_key`, converts each tensor to a NumPy array and pickles
    the result to ``argv[2]``. Each ``old -> new`` mapping is printed.

    Args:
        argv: Argument vector including the program name; defaults to
            ``sys.argv``. Arguments beyond the first two are ignored.

    Raises:
        SystemExit: If fewer than two path arguments are supplied.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 3:
        prog = argv[0] if argv else "convert-torchvision-to-d2.py"
        sys.exit("Usage: {} <torchvision.pth> <output.pkl>".format(prog))
    src_path, dst_path = argv[1], argv[2]

    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source (e.g. the official download URL).
    obj = torch.load(src_path, map_location="cpu")

    newmodel = {}
    # Iterate over a snapshot of the keys because entries are popped as we go.
    for old_k in list(obj.keys()):
        k = convert_key(old_k)
        print(old_k, "->", k)
        newmodel[k] = obj.pop(old_k).detach().numpy()

    # NOTE(review): "matching_heuristics" is presumably interpreted by the
    # checkpoint loader that consumes this file -- confirm against the loader.
    res = {"model": newmodel, "__author__": "torchvision", "matching_heuristics": True}

    with open(dst_path, "wb") as f:
        pkl.dump(res, f)
    # Every key is popped in the loop above, so this only fires if something
    # was left behind; kept as a safety net.
    if obj:
        print("Unconverted keys:", obj.keys())


if __name__ == "__main__":
    main()