import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

from .ppat_rgb import Projected, PointPatchTransformer
from .Minkowski import MinkResNet34


def module(state_dict: dict, name):
    """Return the entries of *state_dict* that live under the ``name.`` prefix.

    Only keys starting with ``name + '.'`` are kept, and that prefix is
    removed from each kept key. The prefix is sliced off by length, so a
    dotted ``name`` (for example ``'a.b'``) is stripped in full rather than
    only up to its first dot.

    Args:
        state_dict: Mapping from dotted parameter names to values.
        name: Prefix (without the trailing dot) selecting the sub-dictionary.

    Returns:
        A new dict with the prefix-stripped keys and the original values.
    """
    prefix = name + '.'
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def G14(s):
    """Build the encoder registered as ``openshape-pointbert-vitg14-rgb``.

    A ``PointPatchTransformer`` is wrapped in ``Projected`` together with an
    ``nn.Linear(512, 1280)`` layer. Weights are read from ``s['state_dict']``
    under the ``module.`` prefix.

    Args:
        s: Loaded checkpoint; must contain a ``'state_dict'`` entry.

    Returns:
        The model with the checkpoint weights loaded.
    """
    model = Projected(
        PointPatchTransformer(512, 12, 8, 512*3, 256, 384, 0.2, 64, 6),
        nn.Linear(512, 1280)
    )
    model.load_state_dict(module(s['state_dict'], 'module'))
    return model


def L14(s):
    """Build the encoder registered as ``openshape-pointbert-vitl14-rgb``.

    A ``PointPatchTransformer`` is wrapped in ``Projected`` together with an
    ``nn.Linear(512, 768)`` layer. Weights are read from ``s`` under the
    ``pc_encoder.`` prefix.

    Args:
        s: Loaded checkpoint, used directly as the state dict.

    Returns:
        The model with the checkpoint weights loaded.
    """
    model = Projected(
        PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6),
        nn.Linear(512, 768)
    )
    model.load_state_dict(module(s, 'pc_encoder'))
    return model


def B32(s):
    """Build the encoder registered as ``openshape-pointbert-vitb32-rgb``.

    This is a bare ``PointPatchTransformer`` (no ``Projected`` wrapper).
    Weights are read from ``s`` under the ``pc_encoder.`` prefix.

    Args:
        s: Loaded checkpoint, used directly as the state dict.

    Returns:
        The model with the checkpoint weights loaded.
    """
    model = PointPatchTransformer(512, 12, 8, 1024, 128, 64, 0.4, 256, 6)
    model.load_state_dict(module(s, 'pc_encoder'))
    return model


def Mk34(s):
    """Build the encoder registered as ``tripletmix-spconv-all``.

    A ``MinkResNet34`` is created with its default arguments. Weights are
    read from ``s`` under the ``pc_encoder.`` prefix.

    Args:
        s: Loaded checkpoint, used directly as the state dict.

    Returns:
        The model with the checkpoint weights loaded.
    """
    model = MinkResNet34()
    model.load_state_dict(module(s, 'pc_encoder'))
    return model


# Maps a hub model name to the builder that turns its checkpoint into a model.
model_list = {
    "openshape-pointbert-vitb32-rgb": B32,
    "openshape-pointbert-vitl14-rgb": L14,
    "openshape-pointbert-vitg14-rgb": G14,
    "tripletmix-spconv-all": Mk34,
}


def _load_from_hub(repo_prefix, name):
    """Download ``model.pt`` from ``repo_prefix + name`` and build the model."""
    # NOTE(review): torch.load can unpickle arbitrary objects (depending on the
    # installed torch version's `weights_only` default), and this file comes
    # from a remote hub repository -- confirm the repositories are trusted.
    checkpoint = torch.load(
        hf_hub_download(repo_prefix + name, "model.pt"),
        map_location='cpu',
    )
    # The builder lookup happens after the download, as in the original code.
    model = model_list[name](checkpoint).eval()
    if torch.cuda.is_available():
        model.cuda()
    return model


def load_pc_encoder(name):
    """Load a point-cloud encoder from the ``OpenShape/<name>`` hub repository.

    The checkpoint ``model.pt`` is downloaded, loaded onto the CPU, passed to
    the builder registered for ``name`` in ``model_list``, switched to eval
    mode, and moved to CUDA when ``torch.cuda.is_available()`` is true.

    Args:
        name: Model name; appended to ``"OpenShape/"`` to form the repository
            id and used as the key into ``model_list``.

    Returns:
        The loaded model in eval mode.

    Raises:
        KeyError: If the download succeeds but ``name`` is not a key of
            ``model_list``.
    """
    return _load_from_hub("OpenShape/", name)


def load_pc_encoder_mix(name):
    """Load a point-cloud encoder from the ``TripletMix/<name>`` hub repository.

    Same procedure as ``load_pc_encoder`` except that the repository id is
    formed by appending ``name`` to ``"TripletMix/"``.

    Args:
        name: Model name; appended to ``"TripletMix/"`` to form the repository
            id and used as the key into ``model_list``.

    Returns:
        The loaded model in eval mode.

    Raises:
        KeyError: If the download succeeds but ``name`` is not a key of
            ``model_list``.
    """
    return _load_from_hub("TripletMix/", name)