Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
import logging | |
import os | |
import random | |
import subprocess | |
from urllib.parse import urlparse | |
import numpy as np | |
import torch | |
from torch import nn | |
# Module-level logger registered under the "dinov2" name; load_pretrained_weights
# reports through it.
logger = logging.getLogger(name="dinov2")
def load_pretrained_weights(model, pretrained_weights, checkpoint_key): | |
if urlparse(pretrained_weights).scheme: # If it looks like an URL | |
state_dict = torch.hub.load_state_dict_from_url( | |
pretrained_weights, map_location="cpu" | |
) | |
else: | |
state_dict = torch.load(pretrained_weights, map_location="cpu") | |
if checkpoint_key is not None and checkpoint_key in state_dict: | |
logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") | |
state_dict = state_dict[checkpoint_key] | |
# remove `module.` prefix | |
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} | |
# remove `backbone.` prefix induced by multicrop wrapper | |
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} | |
msg = model.load_state_dict(state_dict, strict=False) | |
logger.info( | |
"Pretrained weights found at {} and loaded with msg: {}".format( | |
pretrained_weights, msg | |
) | |
) | |
def fix_random_seeds(seed=31): | |
""" | |
Fix random seeds. | |
""" | |
torch.manual_seed(seed) | |
torch.cuda.manual_seed_all(seed) | |
np.random.seed(seed) | |
random.seed(seed) | |
def get_sha():
    """
    Describe the git state of the checkout that contains this file.

    Returns:
        ``"sha: <sha>, status: <status>, branch: <branch>"``. Here ``<sha>`` is
        the HEAD commit, ``<status>`` is ``"clean"`` or
        ``"has uncommitted changes"``, and ``<branch>`` is the abbreviated
        HEAD ref. A field git could not provide keeps its placeholder
        (``"N/A"`` for sha/branch, ``"clean"`` for status). The function does
        not raise when git is missing or the directory is not a repository.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        # UTF-8 rather than ASCII: branch names may contain non-ASCII
        # characters, which previously made the whole lookup bail out.
        # stderr is silenced so a non-repo doesn't print "fatal: ...".
        out = subprocess.check_output(command, cwd=cwd, stderr=subprocess.DEVNULL)
        return out.decode("utf-8", errors="replace").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        # Output is unused. NOTE(review): presumably run for its side effect
        # of refreshing git's cached file stats so that diff-index below does
        # not report stat-only changes; confirm before removing.
        subprocess.check_output(["git", "diff"], cwd=cwd, stderr=subprocess.DEVNULL)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except (OSError, subprocess.CalledProcessError):
        # Best effort: git not installed (OSError) or not inside a repository
        # (non-zero exit). Keep whatever fields were already resolved.
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message
class CosineScheduler:
    """
    Precomputed per-iteration value schedule of length ``total_iters``.

    The schedule has three consecutive phases:
      1. ``freeze_iters`` zeros;
      2. ``warmup_iters`` values rising linearly from ``start_warmup_value``
         to ``base_value`` (inclusive, via ``np.linspace``);
      3. a cosine curve starting at ``base_value`` and heading toward
         ``final_value`` over the remaining iterations.

    Indexing with ``it >= total_iters`` returns ``final_value``.
    """

    def __init__(
        self,
        base_value,
        final_value,
        total_iters,
        warmup_iters=0,
        start_warmup_value=0,
        freeze_iters=0,
    ):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        # Cosine phase: step k of n maps to angle pi * k / n.
        decay_steps = np.arange(total_iters - warmup_iters - freeze_iters)
        angles = np.pi * decay_steps / len(decay_steps)
        half_range = 0.5 * (base_value - final_value)
        decay = final_value + half_range * (1 + np.cos(angles))

        frozen = np.zeros((freeze_iters))
        ramp = np.linspace(start_warmup_value, base_value, warmup_iters)

        self.schedule = np.concatenate((frozen, ramp, decay))
        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        return self.final_value if it >= self.total_iters else self.schedule[it]
def has_batchnorms(model):
    """
    Return True if ``model`` or any of its submodules is a batch-norm layer.

    Matches ``nn.BatchNorm1d``, ``nn.BatchNorm2d``, ``nn.BatchNorm3d`` and
    ``nn.SyncBatchNorm`` (and their subclasses) via ``isinstance``.
    ``Module.modules()`` yields ``model`` itself first, so a bare batch-norm
    layer also counts.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for module in model.modules())