# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

import logging
import numpy as np
import pickle
from enum import Enum
from typing import Optional
import torch
from torch import nn

from detectron2.config import CfgNode
from detectron2.utils.file_io import PathManager

from .vertex_direct_embedder import VertexDirectEmbedder
from .vertex_feature_embedder import VertexFeatureEmbedder


class EmbedderType(Enum):
    """
    Embedder type which defines how vertices are mapped into the embedding space:
     - "vertex_direct": direct vertex embedding
     - "vertex_feature": embedding vertex features
    """

    VERTEX_DIRECT = "vertex_direct"
    VERTEX_FEATURE = "vertex_feature"


def create_embedder(embedder_spec: CfgNode, embedder_dim: int) -> nn.Module:
    """
    Create an embedder based on the provided configuration

    Args:
        embedder_spec (CfgNode): embedder configuration; reads TYPE, NUM_VERTICES,
            INIT_FILE and IS_TRAINABLE, plus FEATURE_DIM and FEATURES_TRAINABLE
            for the "vertex_feature" type
        embedder_dim (int): embedding space dimensionality
    Return:
        An embedder instance for the specified configuration
    Raises:
        ValueError, in case of unexpected embedder type (an unknown TYPE string
        is already rejected by the ``EmbedderType`` lookup)
    """
    embedder_type = EmbedderType(embedder_spec.TYPE)
    if embedder_type == EmbedderType.VERTEX_DIRECT:
        embedder = VertexDirectEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            embed_dim=embedder_dim,
        )
    elif embedder_type == EmbedderType.VERTEX_FEATURE:
        embedder = VertexFeatureEmbedder(
            num_vertices=embedder_spec.NUM_VERTICES,
            feature_dim=embedder_spec.FEATURE_DIM,
            embed_dim=embedder_dim,
            train_features=embedder_spec.FEATURES_TRAINABLE,
        )
    else:
        raise ValueError(f"Unexpected embedder type {embedder_type}")

    # Both embedder types share the same optional initialization from file
    if embedder_spec.INIT_FILE != "":
        embedder.load(embedder_spec.INIT_FILE)
    if not embedder_spec.IS_TRAINABLE:
        embedder.requires_grad_(False)
    return embedder


class Embedder(nn.Module):
    """
    Embedder module that serves as a container for embedders to use with different
    meshes. Extends Module to automatically save / load state dict.
    """

    DEFAULT_MODEL_CHECKPOINT_PREFIX = "roi_heads.embedder."

    def __init__(self, cfg: CfgNode):
        """
        Initialize mesh embedders. An embedder for mesh `i` is stored in a submodule
        "embedder_{i}".

        Args:
            cfg (CfgNode): configuration options; reads
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE,
                MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS and MODEL.WEIGHTS
        """
        super().__init__()
        self.mesh_names = set()
        embedder_dim = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_SIZE
        logger = logging.getLogger(__name__)
        for mesh_name, embedder_spec in cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDERS.items():
            logger.info(
                "Adding embedder embedder_%s with spec %s", mesh_name, embedder_spec
            )
            self.add_module(
                f"embedder_{mesh_name}", create_embedder(embedder_spec, embedder_dim)
            )
            self.mesh_names.add(mesh_name)
        # Embedder parameters found in the model weights (if any) override the
        # per-embedder initialization done above
        if cfg.MODEL.WEIGHTS != "":
            self.load_from_model_checkpoint(cfg.MODEL.WEIGHTS)

    def load_from_model_checkpoint(self, fpath: str, prefix: Optional[str] = None) -> None:
        """
        Load embedder parameters from a full-model checkpoint.

        Only entries of ``checkpoint["model"]`` whose keys start with ``prefix`` are
        used; the prefix is stripped before the entries are loaded into this module.
        If the checkpoint has no "model" entry, nothing is loaded.

        Args:
            fpath (str): checkpoint path; files ending in ".pkl" are read with
                ``pickle``, anything else with ``torch.load`` mapped to CPU
            prefix (Optional[str]): key prefix that selects embedder parameters;
                defaults to ``Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX``
        """
        if prefix is None:
            prefix = Embedder.DEFAULT_MODEL_CHECKPOINT_PREFIX
        # NOTE(review): both pickle.load and torch.load can execute arbitrary code
        # while deserializing; only load checkpoints from trusted sources.
        with PathManager.open(fpath, "rb") as hFile:
            if fpath.endswith(".pkl"):
                # presumably latin1 is needed for pickles written under Python 2
                state_dict = pickle.load(hFile, encoding="latin1")
            else:
                state_dict = torch.load(hFile, map_location=torch.device("cpu"))
        if state_dict is None or "model" not in state_dict:
            return
        state_dict_local = {}
        for key, value in state_dict["model"].items():
            if not key.startswith(prefix):
                continue
            # .pkl checkpoints may store parameters as numpy arrays
            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            state_dict_local[key[len(prefix) :]] = value
        # non-strict loading to finetune on different meshes
        self.load_state_dict(state_dict_local, strict=False)

    def forward(self, mesh_name: str) -> torch.Tensor:
        """
        Produce vertex embeddings for the specific mesh; vertex embeddings are
        a tensor of shape [N, D] where:
            N = number of vertices
            D = number of dimensions in the embedding space
        Args:
            mesh_name (str): name of a mesh for which to obtain vertex embeddings
        Return:
            Vertex embeddings, a tensor of shape [N, D]
        """
        return getattr(self, f"embedder_{mesh_name}")()

    def has_embeddings(self, mesh_name: str) -> bool:
        """
        Check whether an "embedder_{mesh_name}" attribute (normally the submodule
        registered in ``__init__``) exists on this container.

        Args:
            mesh_name (str): name of the mesh to check
        Return:
            True if embeddings can be requested for ``mesh_name`` via ``forward``
        """
        return hasattr(self, f"embedder_{mesh_name}")